mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
107 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 37103baa07 | |||
| 35c5d43255 | |||
| 95c1e32aa5 | |||
| e4db65a127 | |||
| 0053defa2e | |||
| fd5d8b3d5f | |||
| 5bf82f8229 | |||
| 5ca3920611 | |||
| 8bde9d0ab7 | |||
| abcbc16126 | |||
| e4fd30a8d4 | |||
| 5f759b1637 | |||
| 6a75b4761a | |||
| e5ade5565d | |||
| 0524551f52 | |||
| 862bc7ef85 | |||
| d38792d6e5 | |||
| db3cf0158c | |||
| 0535f2a59a | |||
| 2805ae347c | |||
| 28ef6fcd14 | |||
| 7fc7ec75bb | |||
| 87890cbf38 | |||
| 5326ffe77e | |||
| a1734cf575 | |||
| 82f300e880 | |||
| 3e7c9d7afc | |||
| e9cb779eab | |||
| 8ff95be04c | |||
| f02ce69df0 | |||
| 1feb7b5d88 | |||
| fbe9009db2 | |||
| c0013b130b | |||
| c4763f61a1 | |||
| b95c219d96 | |||
| 9b1138171e | |||
| 023b8f3466 | |||
| 1cad87ebd2 | |||
| 99de7567e6 | |||
| 21baa8fa02 | |||
| 8b4a5368b3 | |||
| f5c6b03b61 | |||
| e7be2fd113 | |||
| b632490b4b | |||
| 9a9c7208d2 | |||
| 427b97d198 | |||
| 2c2bb1e8bf | |||
| 4b24f94225 | |||
| 670a278cbc | |||
| fc74001202 | |||
| f14ac5d486 | |||
| 7bd0d62ce5 | |||
| 7eccefe235 | |||
| b72274066e | |||
| 20f2910b63 | |||
| fd4ae3466b | |||
| 7beb040e8e | |||
| 05bd18f453 | |||
| 8077456c00 | |||
| 5595887fd0 | |||
| 41959389b6 | |||
| 2c4e888c7f | |||
| 5ced72e6b8 | |||
| 907023f9f7 | |||
| 4ba23ea029 | |||
| 409ac0baca | |||
| 699363f9fc | |||
| ae7a54de57 | |||
| fb9139b882 | |||
| 9fe3a3fb17 | |||
| 26cb9a24c3 | |||
| 77106697c3 | |||
| 75bc44c166 | |||
| f2b79656eb | |||
| 14c2ece004 | |||
| 35612c61e1 | |||
| f7bb3e2d90 | |||
| 1e0d667a22 | |||
| 33969a0337 | |||
| fa26290e8c | |||
| e9f7f5127b | |||
| 097842c70f | |||
| 3b8a3a32a0 | |||
| 1c56779dd9 | |||
| 83a4338f8b | |||
| 730c7b2f35 | |||
| 116059a43e | |||
| b08149a113 | |||
| c227107f60 | |||
| 01dc289f3d | |||
| 6830ca7645 | |||
| ed42c71fc3 | |||
| e0139065bd | |||
| e509f255af | |||
| e2fcd140b0 | |||
| 2a7a0e6129 | |||
| 9f33791b19 | |||
| 453e0a995f | |||
| 8ebf79c494 | |||
| 8774aec304 | |||
| ac742c9f0d | |||
| cd13f1ecfd | |||
| 9aa632968f | |||
| 62caaf07b0 | |||
| 3355f04ca6 | |||
| 769f531603 | |||
| f6c7287ae7 |
+382
-56
@@ -4,7 +4,13 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient
|
||||
|
||||
HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process.
|
||||
|
||||
It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
It combines three key ingredients:
|
||||
|
||||
1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
|
||||
|
||||
2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
|
||||
|
||||
3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
|
||||
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
|
||||
|
||||
@@ -56,30 +62,243 @@ pip install -e ".[hilserl]"
|
||||
|
||||
### Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`, which is defined as:
|
||||
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
class GymManipulatorConfig:
|
||||
env: HILSerlRobotEnvConfig # Environment configuration (nested)
|
||||
dataset: DatasetConfig # Dataset recording/replay configuration (nested)
|
||||
mode: str | None = None # "record", "replay", or None (for training)
|
||||
device: str = "cpu" # Compute device
|
||||
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`)
|
||||
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`)
|
||||
wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py`
|
||||
fps: int = 10 # Control frequency
|
||||
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm
|
||||
processor: HILSerlProcessorConfig # Processing pipeline configuration (nested)
|
||||
name: str = "real_robot" # Environment name
|
||||
mode: str = None # "record", "replay", or None (for training)
|
||||
repo_id: str | None = None # LeRobot dataset repository ID
|
||||
dataset_root: str | None = None # Local dataset root (optional)
|
||||
task: str = "" # Task identifier
|
||||
num_episodes: int = 10 # Number of episodes for recording
|
||||
episode: int = 0 # episode index for replay
|
||||
device: str = "cuda" # Compute device
|
||||
push_to_hub: bool = True # Whether to push the recorded datasets to Hub
|
||||
pretrained_policy_name_or_path: str | None = None # For policy loading
|
||||
reward_classifier_pretrained_path: str | None = None # For reward model
|
||||
number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier
|
||||
task: str | None = None # Task identifier
|
||||
fps: int = 10 # Control frequency
|
||||
|
||||
# Nested processor configuration
|
||||
class HILSerlProcessorConfig:
|
||||
control_mode: str = "gamepad" # Control mode
|
||||
observation: ObservationConfig | None = None # Observation processing settings
|
||||
image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings
|
||||
gripper: GripperConfig | None = None # Gripper control and penalty settings
|
||||
reset: ResetConfig | None = None # Environment reset and timing settings
|
||||
inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings
|
||||
reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings
|
||||
max_gripper_pos: float | None = 100.0 # Maximum gripper position
|
||||
|
||||
# Sub-configuration classes
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters
|
||||
resize_size: tuple[int, int] | None = None # Target image size
|
||||
|
||||
class GripperConfig:
|
||||
use_gripper: bool = True # Enable gripper control
|
||||
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
|
||||
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
|
||||
|
||||
class ResetConfig:
|
||||
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
|
||||
reset_time_s: float = 5.0 # Time to wait during reset
|
||||
control_time_s: float = 20.0 # Maximum episode duration
|
||||
terminate_on_success: bool = True # Whether to terminate episodes on success detection
|
||||
|
||||
class InverseKinematicsConfig:
|
||||
urdf_path: str | None = None # Path to robot URDF file
|
||||
target_frame_name: str | None = None # End-effector frame name
|
||||
end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds
|
||||
end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis
|
||||
|
||||
class RewardClassifierConfig:
|
||||
pretrained_path: str | None = None # Path to pretrained reward classifier
|
||||
success_threshold: float = 0.5 # Success detection threshold
|
||||
success_reward: float = 1.0 # Reward value for successful episodes
|
||||
|
||||
# Dataset configuration
|
||||
class DatasetConfig:
|
||||
repo_id: str # LeRobot dataset repository ID
|
||||
dataset_root: str # Local dataset root directory
|
||||
task: str # Task identifier
|
||||
num_episodes: int # Number of episodes for recording
|
||||
episode: int # Episode index for replay
|
||||
push_to_hub: bool # Whether to push datasets to Hub
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Processor Pipeline Architecture
|
||||
|
||||
HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components:
|
||||
|
||||
#### Environment Processor Pipeline
|
||||
|
||||
The environment processor (`env_processor`) handles incoming observations and environment state:
|
||||
|
||||
1. **VanillaObservationProcessor**: Converts raw robot observations into standardized format
|
||||
2. **JointVelocityProcessor** (optional): Adds joint velocity information to observations
|
||||
3. **MotorCurrentProcessor** (optional): Adds motor current readings to observations
|
||||
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
|
||||
5. **ImageCropResizeProcessor** (optional): Crops and resizes camera images
|
||||
6. **TimeLimitProcessor** (optional): Enforces episode time limits
|
||||
7. **GripperPenaltyProcessor** (optional): Applies penalties for inappropriate gripper usage
|
||||
8. **RewardClassifierProcessor** (optional): Automated reward detection using vision models
|
||||
9. **ToBatchProcessor**: Converts data to batch format for neural network processing
|
||||
10. **DeviceProcessor**: Moves data to the specified compute device (CPU/GPU)
|
||||
|
||||
#### Action Processor Pipeline
|
||||
|
||||
The action processor (`action_processor`) handles outgoing actions and human interventions:
|
||||
|
||||
1. **AddTeleopActionAsComplimentaryData**: Captures teleoperator actions for logging
|
||||
2. **AddTeleopEventsAsInfo**: Records intervention events and episode control signals
|
||||
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
|
||||
4. **InterventionActionProcessor**: Handles human interventions and episode termination
|
||||
5. **Inverse Kinematics Pipeline** (when enabled):
|
||||
- **MapDeltaActionToRobotAction**: Converts delta actions to robot action format
|
||||
- **EEReferenceAndDelta**: Computes end-effector reference and delta movements
|
||||
- **EEBoundsAndSafety**: Enforces workspace safety bounds
|
||||
- **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
|
||||
- **GripperVelocityToJoint**: Handles gripper control commands
|
||||
|
||||
#### Configuration Examples
|
||||
|
||||
**Basic Observation Processing**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Image Processing**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {
|
||||
"observation.images.front": [180, 250, 120, 150],
|
||||
"observation.images.side": [180, 207, 180, 200]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Inverse Kinematics Setup**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"inverse_kinematics": {
|
||||
"urdf_path": "path/to/robot.urdf",
|
||||
"target_frame_name": "end_effector",
|
||||
"end_effector_bounds": {
|
||||
"min": [0.16, -0.08, 0.03],
|
||||
"max": [0.24, 0.2, 0.1]
|
||||
},
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Observation Processing
|
||||
|
||||
The HIL-SERL framework supports additional observation processing features that can improve policy learning:
|
||||
|
||||
#### Joint Velocity Processing
|
||||
|
||||
Enable joint velocity estimation to provide the policy with motion information:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This processor:
|
||||
|
||||
- Estimates joint velocities using finite differences between consecutive joint position readings
|
||||
- Adds velocity information to the observation state vector
|
||||
- Useful for policies that need motion awareness for dynamic tasks
|
||||
|
||||
#### Motor Current Processing
|
||||
|
||||
Monitor motor currents to detect contact forces and load conditions:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_current_to_observation": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This processor:
|
||||
|
||||
- Reads motor current values from the robot's control system
|
||||
- Adds current measurements to the observation state vector
|
||||
- Helps detect contact events, object weights, and mechanical resistance
|
||||
- Useful for contact-rich manipulation tasks
|
||||
|
||||
#### Combined Observation Processing
|
||||
|
||||
You can enable multiple observation processing features simultaneously:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": true,
|
||||
"add_ee_pose_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data.
|
||||
|
||||
### Finding Robot Workspace Bounds
|
||||
|
||||
Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.
|
||||
@@ -130,22 +349,56 @@ With the bounds defined, you can safely collect demonstrations for training. Tra
|
||||
|
||||
Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)):
|
||||
|
||||
1. Set `mode` to `"record"`
|
||||
2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name")
|
||||
3. Set `num_episodes` to the number of demonstrations you want to collect
|
||||
4. Set `crop_params_dict` to `null` initially (we'll determine crops later)
|
||||
5. Configure `robot`, `cameras`, and other hardware settings
|
||||
1. Set `mode` to `"record"` at the root level
|
||||
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
|
||||
3. Set `num_episodes` in the `dataset` section to the number of demonstrations you want to collect
|
||||
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
|
||||
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
|
||||
|
||||
Example configuration section:
|
||||
|
||||
```json
|
||||
"mode": "record",
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"dataset_root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes": 15,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "real_robot",
|
||||
"fps": 10,
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"observation": {
|
||||
"display_cameras": false
|
||||
},
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {},
|
||||
"resize_size": [128, 128]
|
||||
},
|
||||
"gripper": {
|
||||
"use_gripper": true,
|
||||
"gripper_penalty": 0.0
|
||||
},
|
||||
"reset": {
|
||||
"reset_time_s": 5.0,
|
||||
"control_time_s": 20.0
|
||||
}
|
||||
},
|
||||
"robot": {
|
||||
// ... robot configuration ...
|
||||
},
|
||||
"teleop": {
|
||||
// ... teleoperator configuration ...
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"dataset_root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes": 15,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cpu"
|
||||
}
|
||||
```
|
||||
|
||||
### Using a Teleoperation Device
|
||||
@@ -191,10 +444,20 @@ The gamepad provides a very convenient way to control the robot and the episode
|
||||
To set up the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file.
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"teleop": {
|
||||
"type": "gamepad",
|
||||
"use_gripper": true
|
||||
"type": "gamepad",
|
||||
"use_gripper": true
|
||||
},
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
@@ -216,11 +479,21 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll
|
||||
To set up the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file.
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"teleop": {
|
||||
"type": "so101_leader",
|
||||
"port": "/dev/tty.usbmodem585A0077921", # check your port number
|
||||
"use_degrees": true
|
||||
"type": "so101_leader",
|
||||
"port": "/dev/tty.usbmodem585A0077921",
|
||||
"use_degrees": true
|
||||
},
|
||||
"processor": {
|
||||
"control_mode": "leader",
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure.
|
||||
@@ -251,7 +524,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e
|
||||
|
||||
During recording:
|
||||
|
||||
1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions`
|
||||
1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions`
|
||||
2. Complete the task successfully
|
||||
3. The episode ends with a reward of 1 when you press the "success" button
|
||||
4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
|
||||
@@ -310,11 +583,19 @@ observation.images.front: [180, 250, 120, 150]
|
||||
Add these crop parameters to your training configuration:
|
||||
|
||||
```json
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Recommended image resolution**
|
||||
@@ -343,26 +624,52 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r
|
||||
|
||||
**Key Parameters for Data Collection**
|
||||
|
||||
- **mode**: set it to `"record"` to collect a dataset
|
||||
- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
|
||||
- **num_episodes**: Number of episodes to record
|
||||
- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected
|
||||
- **fps**: Number of frames per second to record
|
||||
- **push_to_hub**: Whether to push the dataset to the hub
|
||||
- **mode**: set it to `"record"` to collect a dataset (at root level)
|
||||
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
|
||||
- **dataset.num_episodes**: Number of episodes to record
|
||||
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
|
||||
- **env.fps**: Number of frames per second to record
|
||||
- **dataset.push_to_hub**: Whether to push the dataset to the hub
|
||||
|
||||
The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states labeled 1 in the dataset to train a good classifier.
|
||||
The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection.
|
||||
|
||||
**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully.
|
||||
|
||||
Example configuration section for data collection:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "real_robot",
|
||||
"fps": 10,
|
||||
"processor": {
|
||||
"reset": {
|
||||
"reset_time_s": 5.0,
|
||||
"control_time_s": 20.0,
|
||||
"terminate_on_success": false
|
||||
},
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
},
|
||||
"robot": {
|
||||
// ... robot configuration ...
|
||||
},
|
||||
"teleop": {
|
||||
// ... teleoperator configuration ...
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"task": "reward_classifier_task",
|
||||
"num_episodes": 20,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"num_episodes": 20,
|
||||
"push_to_hub": true,
|
||||
"fps": 10,
|
||||
"number_of_steps_after_success": 15
|
||||
"device": "cpu"
|
||||
}
|
||||
```
|
||||
|
||||
@@ -421,9 +728,17 @@ To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
env_config = HILSerlRobotEnvConfig(
|
||||
reward_classifier_pretrained_path="path_to_your_pretrained_trained_model",
|
||||
# Other environment parameters
|
||||
config = GymManipulatorConfig(
|
||||
env=HILSerlRobotEnvConfig(
|
||||
processor=HILSerlProcessorConfig(
|
||||
reward_classifier=RewardClassifierConfig(
|
||||
pretrained_path="path_to_your_pretrained_trained_model"
|
||||
)
|
||||
),
|
||||
# Other environment parameters
|
||||
),
|
||||
dataset=DatasetConfig(...),
|
||||
mode=None # For training
|
||||
)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
@@ -432,7 +747,18 @@ or set the argument in the json config file.
|
||||
|
||||
```json
|
||||
{
|
||||
"reward_classifier_pretrained_path": "path_to_your_pretrained_model"
|
||||
"env": {
|
||||
"processor": {
|
||||
"reward_classifier": {
|
||||
"pretrained_path": "path_to_your_pretrained_model",
|
||||
"success_threshold": 0.7,
|
||||
"success_reward": 1.0
|
||||
},
|
||||
"reset": {
|
||||
"terminate_on_success": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
+56
-30
@@ -32,9 +32,12 @@ To use `gym_hil` with LeRobot, you need to create a configuration file. An examp
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "hil",
|
||||
"name": "franka_sim",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
@@ -45,28 +48,40 @@ Available tasks:
|
||||
- `PandaPickCubeGamepad-v0`: With gamepad control
|
||||
- `PandaPickCubeKeyboard-v0`: With keyboard control
|
||||
|
||||
### Gym Wrappers Configuration
|
||||
### Processor Configuration
|
||||
|
||||
```json
|
||||
"wrapper": {
|
||||
"gripper_penalty": -0.02,
|
||||
"control_time_s": 15.0,
|
||||
"use_gripper": true,
|
||||
"fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785],
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
},
|
||||
"control_mode": "gamepad"
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"gripper": {
|
||||
"use_gripper": true,
|
||||
"gripper_penalty": -0.02
|
||||
},
|
||||
"reset": {
|
||||
"control_time_s": 15.0,
|
||||
"fixed_reset_joint_positions": [
|
||||
0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785
|
||||
]
|
||||
},
|
||||
"inverse_kinematics": {
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Important parameters:
|
||||
|
||||
- `gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `use_gripper`: Whether to enable gripper control
|
||||
- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `gripper.gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `gripper.use_gripper`: Whether to enable gripper control
|
||||
- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `control_mode`: Set to `"gamepad"` to use a gamepad controller
|
||||
|
||||
## Running with HIL RL of LeRobot
|
||||
@@ -75,39 +90,50 @@ Important parameters:
|
||||
|
||||
To run the environment, set mode to null:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Recording a Dataset
|
||||
|
||||
To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0"
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/sim_dataset",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes": 10,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record"
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Training a Policy
|
||||
|
||||
To train a policy, check out the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
In a different terminal, run the learner server:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.
|
||||
|
||||
|
||||
@@ -519,11 +519,14 @@ from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.policies.factory import make_processor
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
@@ -535,7 +538,7 @@ robot_config = SO100FollowerConfig(
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Initialize the policy
|
||||
policy = ACTPolicy.from_pretrained("<hf_username>/<my_policy_repo_id>")
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
@@ -544,7 +547,7 @@ dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/eval_<dataset_repo_id>",
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
@@ -559,6 +562,12 @@ _init_rerun(session_name="recording")
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
@@ -568,6 +577,8 @@ for episode_idx in range(NUM_EPISODES):
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
|
||||
+53
-5
@@ -24,11 +24,36 @@ pip install -e ".[hilserl]"
|
||||
|
||||
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json).
|
||||
|
||||
To teleoperate and collect a dataset, we need to modify this config file: add your `repo_id` (`"repo_id": "il_gym"`), set `"num_episodes": 30`, and make sure `mode` is set to record (`"mode": "record"`).
|
||||
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
|
||||
|
||||
If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS).
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_gym",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes": 30,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
By default the config file assumes you use a controller. To use your keyboard, change the environment specified at `"task"` in the config file to `"PandaPickCubeKeyboard-v0"`.
|
||||
Key configuration points:
|
||||
|
||||
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
|
||||
- Set `"num_episodes": 30` to collect 30 demonstration episodes
|
||||
- Ensure `mode` is set to `"record"`
|
||||
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"device": "mps"` (macOS) or `"device": "cpu"`
|
||||
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
|
||||
|
||||
Then we can run this command to start:
|
||||
|
||||
@@ -140,9 +165,32 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
|
||||
|
||||
## Evaluate your policy in Sim
|
||||
|
||||
To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
|
||||
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
|
||||
|
||||
Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model`
|
||||
Here's an example evaluation configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_sim_dataset",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube"
|
||||
},
|
||||
"pretrained_policy_name_or_path": "your_username/il_sim_model",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Make sure to replace:
|
||||
|
||||
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
|
||||
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
|
||||
|
||||
Then you can run this command to visualize your trained policy:
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_processor
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
@@ -11,12 +12,14 @@ NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
policy = ACTPolicy.from_pretrained("<hf_username>/<policy_repo_id>")
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
@@ -25,7 +28,7 @@ dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<eval_dataset_repo_id>",
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
@@ -43,6 +46,12 @@ listener, events = init_keyboard_listener()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
@@ -53,6 +62,8 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
|
||||
@@ -38,7 +38,7 @@ while True:
|
||||
keyboard_keys = keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
|
||||
log_rerun_data(observation, {**arm_action, **base_action})
|
||||
log_rerun_data(observation=observation, action={**arm_action, **base_action})
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import merge_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_processor
|
||||
from lerobot.processor.converters import (
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot with degrees
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose = RobotProcessor(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=to_transition_robot_observation,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build dataset action and gripper features
|
||||
action_ee_and_gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints,
|
||||
initial_features={},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
|
||||
) # Get all ee action features + gripper pos action features
|
||||
|
||||
# Build dataset observation features
|
||||
obs_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=robot.observation_features,
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
) # Get all ee observation features
|
||||
|
||||
dataset_features = merge_features(obs_ee, action_ee_and_gripper)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
robot_action_processor=robot_ee_to_joints,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
dataset.push_to_hub()
|
||||
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import merge_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.converters import (
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
ForwardKinematicsJointsToEE,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone import Phone
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
|
||||
NUM_EPISODES = 10
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action
|
||||
phone_to_robot_ee_pose = RobotProcessor(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose = RobotProcessor(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=to_transition_robot_observation,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build dataset ee action features
|
||||
action_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose,
|
||||
initial_features=phone.action_features,
|
||||
use_videos=True,
|
||||
patterns=["action.ee"],
|
||||
)
|
||||
|
||||
# Get gripper pos action features
|
||||
gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints,
|
||||
initial_features={},
|
||||
use_videos=True,
|
||||
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
|
||||
)
|
||||
|
||||
# Build dataset ee observation features
|
||||
observation_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=robot.observation_features,
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
)
|
||||
|
||||
dataset_features = merge_features(action_ee, gripper, observation_ee)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose,
|
||||
robot_action_processor=robot_ee_to_joints,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose,
|
||||
robot_action_processor=robot_ee_to_joints,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
dataset.push_to_hub()
|
||||
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.converters import to_output_robot_action
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
|
||||
# This method converts the action from the dataset to a transition for pipeline
|
||||
def action_to_transition(action: dict) -> dict:
    """Convert a flat dataset action into a transition dict for the pipeline.

    The end-effector pose entries (``ee.x``, ``ee.y``, ``ee.z``, ``ee.wx``,
    ``ee.wy``, ``ee.wz``) and the ``gripper.pos`` entry are copied from
    ``action`` when present. Each copied key is prefixed with ``action.`` and
    its value is cast to ``float``. Missing keys are skipped and any other
    key in ``action`` is ignored.

    Args:
        action: Mapping from action feature name to a value accepted by
            ``float()``.

    Returns:
        A transition dict whose ``"action"`` entry holds the converted values.
        ``"observation"`` and ``"reward"`` are ``None``, ``"done"`` and
        ``"truncated"`` are ``False``, and ``"info"`` and
        ``"complementary_data"`` are empty dicts.
    """
    # EE pose first, then the gripper (the dataset stores the gripper as an
    # absolute position, so it is forwarded unchanged apart from the cast).
    forwarded_keys = ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz", "gripper.pos")
    act = {f"action.{key}": float(action[key]) for key in forwarded_keys if key in action}

    return {
        "observation": None,
        "action": act,
        "reward": None,
        "done": False,
        "truncated": False,
        "info": {},
        "complementary_data": {},
    }
|
||||
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=action_to_transition,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
robot_ee_to_joints.reset()
|
||||
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
# The action feature names do not change between frames, so read them once.
action_names = dataset.features["action"]["names"]

for idx in range(dataset.num_frames):
    t0 = time.perf_counter()

    # Fetch the recorded action vector once per frame instead of once per element.
    frame_action = actions[idx]["action"]
    ee_action = {name: float(frame_action[i]) for i, name in enumerate(action_names)}

    # EE pose action -> joint action, then command the robot.
    joint_action = robot_ee_to_joints(ee_action)
    robot.send_action(joint_action)

    # Spend the rest of the frame period waiting so the replay follows dataset.fps
    # (presumably busy_wait blocks for the given number of seconds; confirm in robot_utils).
    busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
robot.disconnect()
|
||||
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
# limitations under the License.
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone import Phone
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action
|
||||
phone_to_robot_ee_pose = RobotProcessor(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot.")
|
||||
while True:
    # Poll the phone once per iteration. A falsy result means there is nothing
    # to act on yet, so back off briefly and poll again.
    phone_obs = teleop_device.get_action()
    if not phone_obs:
        time.sleep(0.01)
        continue

    # NOTE: the sample checked above is the one processed below. Fetching a
    # second sample here would bypass the emptiness check.

    # Phone to EE pose transition
    ee_transition = phone_to_robot_ee_pose(phone_obs)

    # EE pose to Joints transition
    joint_action = robot_ee_to_joints(ee_transition)

    # Only command the robot when the pipeline produced a non-empty action.
    if joint_action:
        robot.send_action(joint_action)

    time.sleep(0.01)
|
||||
+5
-2
@@ -73,6 +73,7 @@ dependencies = [
|
||||
"pynput>=1.7.7",
|
||||
"pyserial>=3.5",
|
||||
"wandb>=0.20.0",
|
||||
"scipy>=1.15.2",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
@@ -95,7 +96,7 @@ dependencies = [
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1"]
|
||||
placo-dep = ["placo>=0.9.6"]
|
||||
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
|
||||
transformers-dep = ["transformers<=4.52.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||
|
||||
# Motors
|
||||
@@ -111,6 +112,7 @@ intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
|
||||
# stretch = [
|
||||
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
|
||||
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
@@ -152,7 +154,8 @@ all = [
|
||||
"lerobot[video_benchmark]",
|
||||
"lerobot[aloha]",
|
||||
"lerobot[pusht]",
|
||||
"lerobot[xarm]"
|
||||
"lerobot[xarm]",
|
||||
"lerobot[phone]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
@@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
@@ -24,6 +24,7 @@ class FeatureType(str, Enum):
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
LANGUAGE = "LANGUAGE"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
|
||||
@@ -21,6 +21,7 @@ OBS_ENV_STATE = "observation.environment_state"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
OBS_LANGUAGE = "observation.language"
|
||||
ACTION = "action"
|
||||
REWARD = "next.reward"
|
||||
|
||||
@@ -39,6 +40,9 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
|
||||
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
|
||||
|
||||
def aggregate_pipeline_dataset_features(
    pipeline: RobotProcessor,
    initial_features: dict[str, Any],
    *,
    use_videos: bool = True,
    patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
    """
    Build the dataset feature dict implied by a processing pipeline.

    The pipeline is asked (via `transform_features`) which features it produces, starting
    from `initial_features`. The resulting keys are regrouped by prefix into an "action"
    group and an "observation" group, and each non-empty group is converted with
    `hw_to_dataset_features`.

    Args:
        pipeline: Processor whose `transform_features` describes the produced features.
        initial_features: Seed features handed to the pipeline. When `use_videos` is True,
            every entry whose value is a 3-tuple is additionally treated as a camera spec
            (e.g. {"front": (h, w, c)}) and placed in the observation group.
        use_videos: Gate for image features. When False, neither the seed cameras nor the
            pipeline's `observation.images.*` keys are kept. The flag is also forwarded to
            `hw_to_dataset_features`.
        patterns: Optional regexes (matched with `re.search` against the full key). They
            filter only `action.*` and `observation.state.*` keys; image keys ignore them.
            None keeps everything; an empty sequence keeps no action/state key.

    Returns:
        The merged output of `hw_to_dataset_features` for the action group followed by the
        observation group. Empty if no feature survived.
    """
    import re

    action_prefix = "action."
    state_prefix = "observation.state."
    image_prefix = "observation.images."

    # Ask the pipeline for the full set of features it yields from the seed features.
    pipeline_features = pipeline.transform_features(initial_features)

    def matches_patterns(key: str) -> bool:
        """Return True when `key` passes the optional `patterns` filter."""
        if patterns is None:
            return True
        for pattern in patterns:
            if re.search(pattern, key):
                return True
        return False

    # group name ("action" / "observation") -> {short feature name: spec}
    grouped: dict[str, dict[str, Any]] = {}

    # Seed the observation group with the raw camera specs, only when videos are enabled.
    if use_videos:
        camera_specs = {}
        for cam_name, spec in initial_features.items():
            if isinstance(spec, tuple) and len(spec) == 3:
                camera_specs[cam_name] = spec
        if camera_specs:
            grouped["observation"] = camera_specs

    for feature_key, feature_type in pipeline_features.items():
        if feature_key.startswith(action_prefix):
            # action.<feat>: subject to the pattern filter
            if matches_patterns(feature_key):
                short_name = feature_key[len(action_prefix) :]
                grouped.setdefault("action", {})[short_name] = feature_type
        elif feature_key.startswith(state_prefix):
            # observation.state.<feat>: subject to the pattern filter
            if matches_patterns(feature_key):
                short_name = feature_key[len(state_prefix) :]
                grouped.setdefault("observation", {})[short_name] = feature_type
        elif feature_key.startswith(image_prefix):
            # observation.images.<cam>: controlled by use_videos only, never by patterns
            if use_videos:
                short_name = feature_key[len(image_prefix) :]
                grouped.setdefault("observation", {})[short_name] = feature_type
        # Any other key (e.g. policy-only features) is intentionally dropped.

    dataset_features: dict[str, dict] = {}
    # Action first, then observation, so later updates follow the same precedence.
    for group_name in ("action", "observation"):
        if group_name in grouped:
            dataset_features.update(hw_to_dataset_features(grouped[group_name], group_name, use_videos))

    return dataset_features
|
||||
@@ -470,6 +470,50 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
return policy_features
|
||||
|
||||
|
||||
def merge_features(*dicts: dict) -> dict:
    """
    Merge LeRobot grouped feature dicts, processing the arguments left to right.

    - "Named vector" specs are merged: an entry counts as a named vector when it is a dict
      whose dtype is not image/video/string, whose "shape" is a 1-element tuple and whose
      "names" is a list or tuple. Names are unioned in first-seen order (no duplicates) and
      the shape is recomputed from the number of names. The merged entry only carries
      "dtype", "names" and "shape".
    - Every other entry (images/videos, non-1D specs, non-dict values, specs whose "names"
      is missing, None or a mapping) is stored as-is: the latest definition wins.

    The input dicts are never mutated: merging always happens in an accumulator owned by
    this function, even when an earlier entry for the same key was stored by reference.

    Args:
        *dicts: Feature dicts mapping a feature key to its spec.

    Returns:
        A new dict with the merged specs.

    Raises:
        ValueError: If two named-vector specs for the same key disagree on dtype.
    """
    out: dict = {}
    # Keys whose `out[...]` entry is an accumulator created here (safe to mutate).
    owned: set = set()

    for d in dicts:
        for key, value in d.items():
            names = value.get("names") if isinstance(value, dict) else None
            is_vector = (
                # Only list/tuple names can be unioned; None would crash and a mapping
                # would be iterated by key, corrupting both names and shape.
                isinstance(names, (list, tuple))
                and value.get("dtype") not in ("image", "video", "string")
                and isinstance(value.get("shape"), tuple)
                and len(value["shape"]) == 1
            )

            if not is_vector:
                # For images/videos and non-mergeable entries: override with the latest
                # definition. The stored object belongs to the caller, so it is no longer
                # considered an owned accumulator.
                out[key] = value
                owned.discard(key)
                continue

            dtype = value.get("dtype")

            if key not in owned:
                # Start a private accumulator. If a caller-owned spec with list-like names
                # is already stored under this key, continue from a copy of it instead of
                # appending into the caller's list.
                prior = out.get(key)
                target = {"dtype": dtype, "names": [], "shape": (0,)}
                if isinstance(prior, dict) and isinstance(prior.get("names"), (list, tuple)):
                    target["dtype"] = prior.get("dtype", dtype)
                    target["names"] = list(prior["names"])
                out[key] = target
                owned.add(key)

            target = out[key]

            # Ensure consistent data types across merged entries
            if dtype != target["dtype"]:
                raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")

            # Merge feature names: append only new ones to preserve order without duplicates
            seen = set(target["names"])
            for n in names:
                if n not in seen:
                    target["names"].append(n)
                    seen.add(n)

            # Recompute the shape to reflect the updated number of features
            target["shape"] = (len(target["names"]),)

    return out
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
|
||||
+57
-86
@@ -161,35 +161,73 @@ class XarmEnv(EnvConfig):
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
class ImagePreprocessingConfig:
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvTransformConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
class RewardClassifierConfig:
|
||||
"""Configuration for reward classification."""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
success_threshold: float = 0.5
|
||||
success_reward: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class InverseKinematicsConfig:
|
||||
"""Configuration for inverse kinematics processing."""
|
||||
|
||||
urdf_path: str | None = None
|
||||
target_frame_name: str | None = None
|
||||
end_effector_bounds: dict[str, list[float]] | None = None
|
||||
end_effector_step_sizes: dict[str, float] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationConfig:
|
||||
"""Configuration for observation processing."""
|
||||
|
||||
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||
control_mode: str = "gamepad"
|
||||
display_cameras: bool = False
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Any | None = None
|
||||
reset_time_s: float = 5.0
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class GripperConfig:
|
||||
"""Configuration for gripper control and penalties."""
|
||||
|
||||
use_gripper: bool = True
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetConfig:
|
||||
"""Configuration for environment reset behavior."""
|
||||
|
||||
fixed_reset_joint_positions: Any | None = None
|
||||
reset_time_s: float = 5.0
|
||||
control_time_s: float = 20.0
|
||||
terminate_on_success: bool = True
|
||||
|
||||
|
||||
@dataclass
class HILSerlProcessorConfig:
    """Top-level settings of the environment processing pipeline.

    Groups the per-concern sub-configurations; each sub-configuration is optional and
    defaults to None.
    """

    # Input device / control scheme name ("gamepad" by default).
    control_mode: str = "gamepad"
    # Observation processing settings.
    observation: ObservationConfig | None = None
    # Image cropping / resizing settings.
    image_preprocessing: ImagePreprocessingConfig | None = None
    # Gripper control and penalty settings.
    gripper: GripperConfig | None = None
    # Environment reset settings.
    reset: ResetConfig | None = None
    # Inverse-kinematics settings.
    inverse_kinematics: InverseKinematicsConfig | None = None
    # Reward classifier settings.
    reward_classifier: RewardClassifierConfig | None = None
    # Upper bound of the gripper position (100.0 by default); may be None.
    max_gripper_pos: float | None = 100.0
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
@@ -197,77 +235,10 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
|
||||
robot: RobotConfig | None = None
|
||||
teleop: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
fps: int = 10
|
||||
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
|
||||
|
||||
name: str = "real_robot"
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
task: str | None = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("hil")
|
||||
@dataclass
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
name: str = "PandaPickCube"
|
||||
task: str | None = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
state_dim: int = 18
|
||||
action_dim: int = 4
|
||||
fps: int = 100
|
||||
episode_length: int = 100
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
robot_config: RobotConfig | None = None
|
||||
teleop_config: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
############################
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"use_viewer": self.use_viewer,
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "hil":
|
||||
return HILEnvConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
@@ -15,6 +15,17 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0.processor_pi0 import Pi0NewLineProcessor
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"PI0Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
]
|
||||
|
||||
@@ -35,7 +35,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
@@ -51,27 +50,16 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
@@ -137,23 +125,19 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
actions = self.model(batch)[0]
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
@@ -303,7 +287,7 @@ class ACT(nn.Module):
|
||||
└───────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config: ACTConfig):
|
||||
def __init__(self, config: ACTConfig, dataset_stats=None):
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
super().__init__()
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_act_processor(
    config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
    """Build the pre- and post-processing pipelines for an ACT policy.

    Args:
        config: ACT configuration; its input/output features, normalization mapping and
            device are read here.
        dataset_stats: Optional statistics handed to the normalizer and unnormalizer steps.

    Returns:
        A `(preprocessor, postprocessor)` pair. The preprocessor renames (empty map),
        normalizes input and output features, batches and moves data to `config.device`.
        The postprocessor moves data to CPU and unnormalizes the output features.
    """
    preprocessing_steps = [
        RenameProcessor(rename_map={}),
        NormalizerProcessor(
            # Inputs and targets share one normalizer, so both feature sets are passed.
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        ToBatchProcessor(),
        DeviceProcessor(device=config.device),
    ]
    postprocessing_steps = [
        DeviceProcessor(device="cpu"),
        UnnormalizerProcessor(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
    ]

    preprocessor = RobotProcessor(steps=preprocessing_steps, name=PREPROCESSOR_DEFAULT_NAME)
    postprocessor = RobotProcessor(steps=postprocessing_steps, name=POSTPROCESSOR_DEFAULT_NAME)
    return preprocessor, postprocessor
|
||||
@@ -35,7 +35,6 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
@@ -57,7 +56,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -70,14 +68,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
@@ -106,9 +96,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -137,7 +124,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
@@ -153,11 +139,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_diffusion_processor(
    config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
    """Build the pre- and post-processing pipelines for a Diffusion policy.

    Args:
        config: Diffusion configuration; its input/output features, normalization mapping
            and device are read here.
        dataset_stats: Optional statistics handed to the normalizer and unnormalizer steps.

    Returns:
        A `(preprocessor, postprocessor)` pair. The preprocessor renames (empty map),
        normalizes input and output features, batches and moves data to `config.device`.
        The postprocessor moves data to CPU and unnormalizes the output features.
    """
    preprocessing_steps = [
        RenameProcessor(rename_map={}),
        NormalizerProcessor(
            # Inputs and targets share one normalizer, so both feature sets are passed.
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        ToBatchProcessor(),
        DeviceProcessor(device=config.device),
    ]
    postprocessing_steps = [
        DeviceProcessor(device="cpu"),
        UnnormalizerProcessor(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
    ]

    preprocessor = RobotProcessor(steps=preprocessing_steps, name=PREPROCESSOR_DEFAULT_NAME)
    postprocessor = RobotProcessor(steps=postprocessing_steps, name=POSTPROCESSOR_DEFAULT_NAME)
    return preprocessor, postprocessor
|
||||
@@ -14,9 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
@@ -34,9 +39,10 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
@@ -101,6 +107,123 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for the processor config."""
|
||||
|
||||
preprocessor_config_filename: str | None
|
||||
postprocessor_config_filename: str | None
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
|
||||
|
||||
def make_processor(
    policy_cfg: PreTrainedConfig,
    pretrained_path: str | None = None,
    **kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[RobotProcessor, RobotProcessor]:
    """Create the (preprocessor, postprocessor) pair for a policy.

    Two modes are supported:

    - `pretrained_path` is truthy: both processors are loaded with
      `RobotProcessor.from_pretrained` from that location. File names and overrides come
      from `kwargs`, falling back to "robot_preprocessor.json" /
      "robot_postprocessor.json" and empty overrides.
    - otherwise: the processors are built from scratch by the factory matching
      `policy_cfg.type`. Each factory is imported lazily inside its branch.

    Args:
        policy_cfg: Configuration of the policy; its `type` selects the factory.
        pretrained_path: Optional location to load saved processors from.
        **kwargs: See `ProcessorConfigKwargs`. Only `dataset_stats` is used when building
            from scratch; only the filename/override keys are used when loading.

    Returns:
        Tuple of (input_processor, output_processor) for the policy.

    Raises:
        NotImplementedError: If no processor factory exists for `policy_cfg.type`.
    """
    if pretrained_path:
        preprocessor = RobotProcessor.from_pretrained(
            pretrained_model_name_or_path=pretrained_path,
            config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
            overrides=kwargs.get("preprocessor_overrides", {}),
        )
        postprocessor = RobotProcessor.from_pretrained(
            pretrained_model_name_or_path=pretrained_path,
            config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
            overrides=kwargs.get("postprocessor_overrides", {}),
        )
        return preprocessor, postprocessor

    # Building from scratch: dispatch on the policy type.
    policy_type = policy_cfg.type
    stats = kwargs.get("dataset_stats")

    if policy_type == "tdmpc":
        from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
        from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor

        return make_tdmpc_processor(config=cast(TDMPCConfig, policy_cfg), dataset_stats=stats)

    if policy_type == "diffusion":
        from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor

        return make_diffusion_processor(cast(DiffusionConfig, policy_cfg), dataset_stats=stats)

    if policy_type == "act":
        from lerobot.policies.act.processor_act import make_act_processor

        return make_act_processor(config=cast(ACTConfig, policy_cfg), dataset_stats=stats)

    if policy_type == "vqbet":
        from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor

        return make_vqbet_processor(config=cast(VQBeTConfig, policy_cfg), dataset_stats=stats)

    if policy_type == "pi0":
        from lerobot.policies.pi0.processor_pi0 import make_pi0_processor

        return make_pi0_processor(config=cast(PI0Config, policy_cfg), dataset_stats=stats)

    if policy_type == "pi0fast":
        from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor

        # NOTE(review): the cast targets PI0Config, not a pi0fast-specific config class.
        # `cast` has no runtime effect, so this only matters to type checkers.
        return make_pi0fast_processor(cast(PI0Config, policy_cfg), dataset_stats=stats)

    if policy_type == "sac":
        from lerobot.policies.sac.processor_sac import make_sac_processor

        return make_sac_processor(cast(SACConfig, policy_cfg), dataset_stats=stats)

    if policy_type == "reward_classifier":
        from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor

        return make_classifier_processor(cast(RewardClassifierConfig, policy_cfg), dataset_stats=stats)

    if policy_type == "smolvla":
        from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor

        return make_smolvla_processor(cast(SmolVLAConfig, policy_cfg), dataset_stats=stats)

    raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
|
||||
def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
@@ -147,7 +270,6 @@ def make_policy(
|
||||
kwargs = {}
|
||||
if ds_meta is not None:
|
||||
features = dataset_to_policy_features(ds_meta.features)
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
else:
|
||||
if not cfg.pretrained_path:
|
||||
logging.warning(
|
||||
@@ -155,6 +277,8 @@ def make_policy(
|
||||
"rather than a dataset. Normalization modules inside the policy will have infinite values "
|
||||
"by default without stats from a dataset."
|
||||
)
|
||||
if env_cfg is None:
|
||||
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
|
||||
@@ -1,420 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def create_stats_buffers(
    features: dict[str, PolicyFeature],
    norm_map: dict[str, NormalizationMode],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, nn.ParameterDict]:
    """
    Create buffers per modality (e.g. "observation.image", "action") containing their mean/std or
    min/max statistics.

    Args:
        features: Maps each modality key to its feature description. Only `.type` and `.shape` are read.
        norm_map: Looked up with each feature's `.type` to pick its normalization mode. Features whose
            type is missing from the map, or mapped to `IDENTITY`, get no buffer.
        stats: Optional per-key statistics given as `np.ndarray` or `torch.Tensor`. When provided (and
            non-empty), it must contain an entry for every normalized key, holding the statistics
            required by that key's mode ("mean"/"std" for `MEAN_STD`, "min"/"max" for `MIN_MAX`).

    Returns:
        dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
            `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.

    Raises:
        ValueError: If a provided statistic is neither an `np.ndarray` nor a `torch.Tensor`.
        KeyError: If `stats` is provided but lacks a normalized key or one of its required statistics.
    """
    stats_buffers = {}

    for key, ft in features.items():
        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
        if norm_mode is NormalizationMode.IDENTITY:
            continue

        assert isinstance(norm_mode, NormalizationMode)

        shape = tuple(ft.shape)

        if ft.type is FeatureType.VISUAL:
            # sanity checks
            assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
            c, h, w = shape
            assert c < h and c < w, f"{key} is not channel first ({shape=})"
            # override image shape to be invariant to height and width
            shape = (c, 1, 1)

        # Names of the statistics the selected mode needs.
        if norm_mode is NormalizationMode.MEAN_STD:
            stat_names = ("mean", "std")
        elif norm_mode is NormalizationMode.MIN_MAX:
            stat_names = ("min", "max")
        else:
            stat_names = ()

        # Note: we initialize mean, std, min, max to infinity. They should be overwritten
        # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
        # we assert they are not infinity anymore.
        buffer = {}
        if stat_names:
            buffer = nn.ParameterDict(
                {
                    name: nn.Parameter(
                        torch.full(shape, torch.inf, dtype=torch.float32), requires_grad=False
                    )
                    for name in stat_names
                }
            )

        # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
        if stats and stat_names:
            key_stats = stats[key]
            # Dispatch on a statistic this mode actually uses: min/max-only stats have no "mean" entry.
            sample = key_stats[stat_names[0]]
            if isinstance(sample, np.ndarray):
                for name in stat_names:
                    buffer[name].data = torch.from_numpy(key_stats[name]).to(dtype=torch.float32)
            elif isinstance(sample, torch.Tensor):
                # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
                # tensors anywhere (for example, when we use the same stats for normalization and
                # unnormalization). See the logic here
                # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
                for name in stat_names:
                    buffer[name].data = key_stats[name].clone().to(dtype=torch.float32)
            else:
                type_ = type(sample)
                raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")

        stats_buffers[key] = buffer
    return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model."
|
||||
)
|
||||
|
||||
|
||||
class Normalize(nn.Module):
    """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            features (dict): A dictionary where keys are modalities (e.g. "observation.image") and values
                describe the feature. Its `shape` (e.g. `[3,96,96]`) is used to create the tensor buffers
                holding the statistics. For visual features the shape is adjusted to be invariant to
                height and width, assuming a channel-first (c, h, w) format.
            norm_map (dict): A dictionary looked up with each feature's `type` to select its
                normalization mode among:
                - `MEAN_STD`: subtract the mean and divide by standard deviation.
                - `MIN_MAX`: map to [-1, 1] range.
                Feature types absent from the map, or mapped to `IDENTITY`, are left untouched.
            stats (dict, optional): A dictionary where keys are modalities (e.g. "observation.image")
                and values are dictionaries of statistic types and their values (e.g.
                `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for
                training the model for the first time, these statistics will overwrite the default buffers. If
                not provided, as expected for finetuning or evaluation, the default buffers should be
                overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                dataset is not needed to get the stats, since they are already in the policy state_dict.
        """
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats
        stats_buffers = create_stats_buffers(features, norm_map, stats)
        # Each key's statistics are attached as `buffer_<key with dots replaced by underscores>`.
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad()
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of `batch` in which every configured key that is present is normalized.

        Raises:
            AssertionError: If a required statistic still contains infinity (stats were never loaded).
            ValueError: If a feature maps to a normalization mode this module does not handle.
        """
        # TODO: Remove this shallow copy
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        for key, ft in self.features.items():
            if key not in batch:
                # FIXME(aliberts, rcadene): This might lead to silent fail!
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if norm_mode is NormalizationMode.MEAN_STD:
                mean = buffer["mean"]
                std = buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                # The epsilon guards against division by zero when std is 0.
                batch[key] = (batch[key] - mean) / (std + 1e-8)
            elif norm_mode is NormalizationMode.MIN_MAX:
                min_val = buffer["min"]
                max_val = buffer["max"]
                assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
                assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
                # normalize to [0,1]
                batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
                # normalize to [-1, 1]
                batch[key] = batch[key] * 2 - 1
            else:
                raise ValueError(norm_mode)
        return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
    """
    Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
    original range used by the environment.
    """

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            features (dict): A dictionary where keys are modalities (e.g. "action") and values describe
                the feature. Its `shape` is used to create the tensor buffers holding the statistics.
                For visual features the shape is adjusted to be invariant to height and width, assuming
                a channel-first (c, h, w) format.
            norm_map (dict): A dictionary looked up with each feature's `type` to select the
                normalization mode that is being inverted, among:
                - `MEAN_STD`: multiply by the standard deviation and add the mean.
                - `MIN_MAX`: map from [-1, 1] back to [min, max].
                Feature types absent from the map, or mapped to `IDENTITY`, are left untouched.
            stats (dict, optional): A dictionary where keys are modalities (e.g. "action")
                and values are dictionaries of statistic types and their values (e.g.
                `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected for
                training the model for the first time, these statistics will overwrite the default buffers. If
                not provided, as expected for finetuning or evaluation, the default buffers should be
                overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                dataset is not needed to get the stats, since they are already in the policy state_dict.
        """
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        self.stats = stats
        # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
        stats_buffers = create_stats_buffers(features, norm_map, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad()
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of `batch` in which every configured key that is present is unnormalized.

        Raises:
            AssertionError: If a required statistic still contains infinity (stats were never loaded).
            ValueError: If a feature maps to a normalization mode this module does not handle.
        """
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        for key, ft in self.features.items():
            if key not in batch:
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if norm_mode is NormalizationMode.MEAN_STD:
                mean = buffer["mean"]
                std = buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = batch[key] * std + mean
            elif norm_mode is NormalizationMode.MIN_MAX:
                min_val = buffer["min"]
                max_val = buffer["max"]
                assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
                assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
                # map [-1, 1] back to [0, 1], then to [min, max]
                batch[key] = (batch[key] + 1) / 2
                batch[key] = batch[key] * (max_val - min_val) + min_val
            else:
                raise ValueError(norm_mode)
        return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
    module: nn.Module,
    features: dict[str, PolicyFeature],
    norm_map: dict[str, NormalizationMode],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> None:
    """Register statistics buffers (mean/std or min/max) on the given *module*.

    The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
    but is factored out so it can be reused by both classes and stay in sync.

    Args:
        module: Module on which the buffers are registered, named `<key with dots replaced by
            underscores>_<stat>`.
        features: Maps each modality key to its feature description. Only `.type` and `.shape` are read.
        norm_map: Looked up with each feature's `.type` to pick its normalization mode. Features whose
            type is missing from the map, or mapped to `IDENTITY`, get no buffer.
        stats: Optional per-key statistics given as `torch.Tensor` or `np.ndarray`. A key is only
            loaded when both statistics required by its mode are present; otherwise its buffers keep
            the infinity placeholder.

    Raises:
        ValueError: If a statistic is neither a Tensor nor an ndarray, or if the normalization mode is
            not one of `IDENTITY`, `MEAN_STD`, `MIN_MAX`.
    """
    for key, ft in features.items():
        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
        if norm_mode is NormalizationMode.IDENTITY:
            continue

        if norm_mode is NormalizationMode.MEAN_STD:
            stat_names = ("mean", "std")
        elif norm_mode is NormalizationMode.MIN_MAX:
            stat_names = ("min", "max")
        else:
            raise ValueError(norm_mode)

        shape: tuple[int, ...] = tuple(ft.shape)
        if ft.type is FeatureType.VISUAL:
            # reduce spatial dimensions, keep channel dimension only
            c, *_ = shape
            shape = (c, 1, 1)

        prefix = key.replace(".", "_")

        # Only load provided stats when the complete pair for this mode is available.
        has_stats = bool(stats) and key in stats and all(name in stats[key] for name in stat_names)

        for name in stat_names:
            if has_stats:
                value = _stat_to_float32_tensor(stats[key][name], key)
            else:
                # Placeholder: expected to be overwritten by `load_state_dict` before use.
                value = torch.full(shape, torch.inf, dtype=torch.float32)
            module.register_buffer(f"{prefix}_{name}", value)


def _stat_to_float32_tensor(value, key: str) -> Tensor:
    """Return an independent float32 tensor copy of one statistic given as a Tensor or ndarray."""
    if isinstance(value, torch.Tensor):
        # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
        # tensors anywhere (for example, when we use the same stats for normalization and
        # unnormalization). See the logic here
        # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
        return value.clone().to(dtype=torch.float32)
    if isinstance(value, np.ndarray):
        # `torch.tensor` copies the data, so the buffer never aliases the caller's array.
        return torch.tensor(value, dtype=torch.float32)
    raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
|
||||
class NormalizeBuffer(nn.Module):
    """Counterpart of `Normalize` whose statistics live in registered buffers instead of parameters."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        super().__init__()
        self.features = features
        self.norm_map = norm_map

        # Registers `<key with dots replaced by underscores>_<stat>` buffers on this module.
        _initialize_stats_buffers(self, features, norm_map, stats)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of `batch` with every configured, present key normalized."""
        normalized = dict(batch)
        for name, feature in self.features.items():
            if name not in normalized:
                continue

            mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
            if mode is NormalizationMode.IDENTITY:
                continue

            buffer_prefix = name.replace(".", "_")

            if mode is NormalizationMode.MEAN_STD:
                mean = getattr(self, f"{buffer_prefix}_mean")
                std = getattr(self, f"{buffer_prefix}_std")
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                normalized[name] = (normalized[name] - mean) / (std + 1e-8)
            elif mode is NormalizationMode.MIN_MAX:
                lower = getattr(self, f"{buffer_prefix}_min")
                upper = getattr(self, f"{buffer_prefix}_max")
                assert not torch.isinf(lower).any(), _no_stats_error_str("min")
                assert not torch.isinf(upper).any(), _no_stats_error_str("max")
                # Scale to [0, 1] first, then stretch to [-1, 1].
                unit_range = (normalized[name] - lower) / (upper - lower + 1e-8)
                normalized[name] = unit_range * 2 - 1
            else:
                raise ValueError(mode)

        return normalized
|
||||
|
||||
|
||||
class UnnormalizeBuffer(nn.Module):
    """Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        super().__init__()
        self.features = features
        self.norm_map = norm_map

        # Registers `<key with dots replaced by underscores>_<stat>` buffers on this module.
        _initialize_stats_buffers(self, features, norm_map, stats)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of `batch` with every configured, present key unnormalized.

        Raises:
            AssertionError: If a required statistic still contains infinity (stats were never loaded).
            ValueError: If a feature maps to a normalization mode this module does not handle.
        """
        # Shallow copy avoids mutating the input batch, as `Normalize`, `Unnormalize` and
        # `NormalizeBuffer` already do. Callers must use the returned dict.
        batch = dict(batch)
        for key, ft in self.features.items():
            if key not in batch:
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            prefix = key.replace(".", "_")

            if norm_mode is NormalizationMode.MEAN_STD:
                mean = getattr(self, f"{prefix}_mean")
                std = getattr(self, f"{prefix}_std")
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = batch[key] * std + mean
                continue

            if norm_mode is NormalizationMode.MIN_MAX:
                min_val = getattr(self, f"{prefix}_min")
                max_val = getattr(self, f"{prefix}_max")
                assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
                assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
                # map [-1, 1] back to [0, 1], then to [min, max]
                batch[key] = (batch[key] + 1) / 2
                batch[key] = batch[key] * (max_val - min_val) + min_val
                continue

            raise ValueError(norm_mode)

        return batch
|
||||
@@ -56,18 +56,15 @@ from collections import deque
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertConfig,
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
from lerobot.utils.utils import get_safe_dtype, init_logging
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
@@ -223,28 +220,17 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
@@ -253,99 +239,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
|
||||
"""
|
||||
Transform state dict keys to match expected model structure.
|
||||
|
||||
Transformations:
|
||||
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
|
||||
model.paligemma_with_expert.paligemma.lm_head
|
||||
- model.paligemma_with_expert.paligemma.language_model.model ->
|
||||
model.paligemma_with_expert.paligemma.model.language_model
|
||||
- model.paligemma_with_expert.paligemma.vision_tower ->
|
||||
model.paligemma_with_expert.paligemma.model.vision_tower
|
||||
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
|
||||
model.paligemma_with_expert.paligemma.model.multi_modal_projector
|
||||
|
||||
Also handles tied weights between lm_head.weight and
|
||||
embed_tokens.weight.
|
||||
"""
|
||||
import re
|
||||
|
||||
transformed_dict = {}
|
||||
|
||||
transformations = [
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
|
||||
".paligemma_with_expert.paligemma.lm_head",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
|
||||
".paligemma_with_expert.paligemma.model.language_model",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
|
||||
".paligemma_with_expert.paligemma.model.vision_tower",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
|
||||
".paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
),
|
||||
]
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
for pattern, replacement in transformations:
|
||||
new_key = pattern.sub(replacement, new_key)
|
||||
transformed_dict[new_key] = value
|
||||
|
||||
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
|
||||
lm_head_key = None
|
||||
embed_tokens_key = None
|
||||
|
||||
for key in transformed_dict:
|
||||
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
|
||||
lm_head_key = key
|
||||
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
|
||||
embed_tokens_key = key
|
||||
if lm_head_key and embed_tokens_key:
|
||||
break
|
||||
|
||||
if lm_head_key and not embed_tokens_key:
|
||||
embed_tokens_key = lm_head_key.replace(
|
||||
".lm_head.weight", ".model.language_model.embed_tokens.weight"
|
||||
)
|
||||
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
|
||||
elif embed_tokens_key and not lm_head_key:
|
||||
lm_head_key = embed_tokens_key.replace(
|
||||
".model.language_model.embed_tokens.weight", ".lm_head.weight"
|
||||
)
|
||||
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
|
||||
|
||||
return transformed_dict
|
||||
|
||||
@classmethod
def _load_as_safetensor(
    cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
) -> "PI0Policy":
    """Load a safetensors checkpoint into `model`, renaming its keys before they are applied.

    The file is read onto `map_location`, passed through `_transform_state_dict_keys`, loaded with
    the given `strict` flag, and the missing/unexpected keys are logged. Returns the same `model`.
    """
    from safetensors.torch import load_file

    init_logging()

    # Read the checkpoint, then remap its keys before handing it to the model.
    raw_state_dict = load_file(model_file, device=map_location)
    remapped_state_dict = cls._transform_state_dict_keys(raw_state_dict)
    load_result = model.load_state_dict(remapped_state_dict, strict=strict)

    log_model_loading_keys(load_result.missing_keys, load_result.unexpected_keys)
    return model
|
||||
|
||||
def get_optim_params(self) -> dict:
    """Return the parameters to optimize: everything yielded by `self.parameters()`."""
    optim_params = self.parameters()
    return optim_params
|
||||
|
||||
@@ -377,14 +270,13 @@ class PI0Policy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
@@ -394,8 +286,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -410,12 +300,10 @@ class PI0Policy(PreTrainedPolicy):
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("action_is_pad")
|
||||
|
||||
@@ -482,26 +370,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
    """Tokenize `batch["task"]` and return `(input_ids, boolean attention_mask)`.

    Both tensors are moved to the device of `batch[OBS_STATE]`.
    """
    target_device = batch[OBS_STATE].device

    # PaliGemma prompt has to end with a new line
    prompts = [text if text.endswith("\n") else f"{text}\n" for text in batch["task"]]

    encoded = self.language_tokenizer(
        prompts,
        padding="max_length",
        padding_side="right",
        max_length=self.config.tokenizer_max_length,
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"].to(device=target_device)
    attention = encoded["attention_mask"].to(device=target_device, dtype=torch.bool)

    return token_ids, attention
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
@@ -567,7 +435,7 @@ class PI0FlowMatching(nn.Module):
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: PI0Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import (
|
||||
EnvTransition,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.rename_processor import RenameProcessor
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ProcessorStep):
    """Add a new line to the end of the task if it doesn't have one.

    This is required for the PaliGemma tokenizer.
    """

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Ensure the "task" entry of the transition's complementary data ends with a newline.

        A single string or a list of strings is handled. The complementary-data dict is updated in
        place and the same `transition` object is returned. The transition is returned untouched
        when there is no complementary data, no "task" entry, a `None` task, or a task of any
        other type.
        """
        # Check if complementary_data exists
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
        if complementary_data is None or "task" not in complementary_data:
            return transition

        task = complementary_data["task"]
        if task is None:
            return transition

        # Handle both string and list of strings
        if isinstance(task, str):
            # Single string: add newline if not present
            if not task.endswith("\n"):
                complementary_data["task"] = f"{task}\n"
        elif isinstance(task, list) and all(isinstance(t, str) for t in task):
            # List of strings: add newline to each if not present
            complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
        # If task is neither string nor list of strings, leave unchanged

        return transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return `features` unchanged: this step only edits the task text and adds no feature."""
        return features

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return state dictionary (empty for this processor)."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Load state dictionary (no-op for this processor)."""
        pass

    def reset(self) -> None:
        """Reset processor state (no-op for this processor)."""
        pass

    def get_config(self) -> dict[str, Any]:
        """Return configuration for serialization (empty: this processor has no settings)."""
        return {}
|
||||
|
||||
|
||||
def make_pi0_processor(
    config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
    """Build the PI0 pre- and post-processing pipelines.

    Args:
        config: PI0 configuration; its features, normalization mapping, tokenizer max length and
            device are read.
        dataset_stats: Optional statistics handed to the normalizer and unnormalizer steps.

    Returns:
        A `(preprocessor, postprocessor)` pair. The preprocessor renames, normalizes inputs and
        outputs, batches, appends the task newline, tokenizes and moves data to `config.device`.
        The postprocessor moves data to CPU and unnormalizes the output features.
    """
    all_features = {**config.input_features, **config.output_features}

    rename_step = RenameProcessor(rename_map={})  # To mimic the same processor as pretrained one
    normalize_step = NormalizerProcessor(
        features=all_features,
        norm_map=config.normalization_mapping,
        stats=dataset_stats,
    )
    batch_step = ToBatchProcessor()
    newline_step = Pi0NewLineProcessor()  # Add newlines before tokenization for PaliGemma
    tokenize_step = TokenizerProcessor(
        tokenizer_name="google/paligemma-3b-pt-224",
        max_length=config.tokenizer_max_length,
        padding_side="right",
        padding="max_length",
    )
    to_device_step = DeviceProcessor(device=config.device)
    input_steps: list[ProcessorStep] = [
        rename_step,
        normalize_step,
        batch_step,
        newline_step,
        tokenize_step,
        to_device_step,
    ]

    to_cpu_step = DeviceProcessor(device="cpu")
    unnormalize_step = UnnormalizerProcessor(
        features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
    )
    output_steps: list[ProcessorStep] = [to_cpu_step, unnormalize_step]

    preprocessor = RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME)
    postprocessor = RobotProcessor(steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME)
    return preprocessor, postprocessor
|
||||
@@ -58,7 +58,6 @@ from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
@@ -146,14 +145,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FAST(config)
|
||||
|
||||
@@ -221,8 +212,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
@@ -235,8 +224,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -249,8 +236,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss_dict = self.model.forward(batch)
|
||||
return loss_dict["loss"], loss_dict
|
||||
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch

from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.processor import (
    DeviceProcessor,
    NormalizerProcessor,
    RenameProcessor,
    RobotProcessor,
    ToBatchProcessor,
    UnnormalizerProcessor,
)
|
||||
|
||||
|
||||
def make_pi0fast_processor(
    config: PI0FASTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
    """Build the pre- and post-processing pipelines used with the PI0FAST policy.

    The preprocessor renames (with an empty map), normalizes the union of input
    and output features, batches, and moves data to ``config.device``. The
    postprocessor moves data to the CPU and un-normalizes the output features.

    Args:
        config: PI0FAST configuration. Only ``input_features``,
            ``output_features``, ``normalization_mapping`` and ``device`` are read.
        dataset_stats: Statistics handed to the (un)normalizer steps; may be None.

    Returns:
        A ``(preprocessor, postprocessor)`` pair of ``RobotProcessor`` objects.
    """
    input_steps = [
        RenameProcessor(rename_map={}),  # To mimic the same processor as pretrained one
        NormalizerProcessor(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        ToBatchProcessor(),
        DeviceProcessor(device=config.device),
    ]
    output_steps = [
        DeviceProcessor(device="cpu"),
        UnnormalizerProcessor(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]
    return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
        steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
    )
|
||||
@@ -28,7 +28,6 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
from lerobot.policies.normalize import NormalizeBuffer
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
@@ -45,7 +44,6 @@ class SACPolicy(
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -53,7 +51,6 @@ class SACPolicy(
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
self._init_normalization(dataset_stats)
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
@@ -88,8 +85,7 @@ class SACPolicy(
|
||||
|
||||
observations_features = None
|
||||
if self.shared_encoder and self.actor.encoder.has_images:
|
||||
# Cache and normalize image features
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
@@ -391,28 +387,12 @@ class SACPolicy(
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_normalization(self, dataset_stats):
    """Set up ``normalize_inputs`` and ``normalize_targets``.

    Both default to ``nn.Identity``. When ``self.config.dataset_stats`` is set,
    they are replaced by ``NormalizeBuffer`` modules: inputs use the converted
    config statistics, targets use ``dataset_stats`` if truthy and the converted
    config statistics otherwise.
    """
    cfg = self.config
    self.normalize_inputs = nn.Identity()
    self.normalize_targets = nn.Identity()
    if cfg.dataset_stats is None:
        return

    config_stats = _convert_normalization_params_to_tensor(cfg.dataset_stats)
    self.normalize_inputs = NormalizeBuffer(cfg.input_features, cfg.normalization_mapping, config_stats)
    # Explicitly passed statistics take precedence for the targets.
    target_stats = dataset_stats or config_stats
    self.normalize_targets = NormalizeBuffer(cfg.output_features, cfg.normalization_mapping, target_stats)
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic
|
||||
if self.shared_encoder
|
||||
else SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
@@ -424,9 +404,7 @@ class SACPolicy(
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
@@ -434,9 +412,7 @@ class SACPolicy(
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
@@ -490,10 +466,9 @@ class SACPolicy(
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self._init_image_layers()
|
||||
self._init_state_layers()
|
||||
self._compute_output_dim()
|
||||
@@ -568,11 +543,10 @@ class SACObservationEncoder(nn.Module):
|
||||
def forward(
|
||||
self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
|
||||
) -> Tensor:
|
||||
obs = self.input_normalization(obs)
|
||||
parts = []
|
||||
if self.has_images:
|
||||
if cache is None:
|
||||
cache = self.get_cached_image_features(obs, normalize=False)
|
||||
cache = self.get_cached_image_features(obs)
|
||||
parts.append(self._encode_images(cache, detach))
|
||||
if self.has_env:
|
||||
parts.append(self.env_encoder(obs["observation.environment_state"]))
|
||||
@@ -585,7 +559,7 @@ class SACObservationEncoder(nn.Module):
|
||||
"No parts to concatenate, you should have at least one image or environment state or state"
|
||||
)
|
||||
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Extract and optionally cache image features from observations.
|
||||
|
||||
This function processes image observations through the vision encoder once and returns
|
||||
@@ -597,26 +571,17 @@ class SACObservationEncoder(nn.Module):
|
||||
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
|
||||
- Caching these features can provide 2-4x speedup in training and inference
|
||||
|
||||
Normalization behavior:
|
||||
- When called from inside forward(): set normalize=False since inputs are already normalized
|
||||
- When called from outside forward(): set normalize=True to ensure proper input normalization
|
||||
|
||||
Usage patterns:
|
||||
- Called in select_action() with normalize=True
|
||||
- Called in select_action()
|
||||
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
|
||||
- Called internally by forward() with normalize=False
|
||||
- Called internally by forward()
|
||||
|
||||
Args:
|
||||
obs: Dictionary of observation tensors containing image keys
|
||||
normalize: Whether to normalize observations before encoding
|
||||
Set to True when calling directly from outside the encoder's forward method
|
||||
Set to False when calling from within forward() where inputs are already normalized
|
||||
|
||||
Returns:
|
||||
Dictionary mapping image keys to their corresponding encoded features
|
||||
"""
|
||||
if normalize:
|
||||
obs = self.input_normalization(obs)
|
||||
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
|
||||
out = self.image_encoder(batched)
|
||||
chunks = torch.chunk(out, len(self.image_keys), dim=0)
|
||||
@@ -747,7 +712,6 @@ class CriticEnsemble(nn.Module):
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
output_normalization (nn.Module): normalization layer for actions.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
@@ -757,13 +721,11 @@ class CriticEnsemble(nn.Module):
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
output_normalization: nn.Module,
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
@@ -775,11 +737,6 @@ class CriticEnsemble(nn.Module):
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
# NOTE: We normalize actions it helps for sample efficiency
|
||||
actions: dict[str, torch.tensor] = {"action": actions}
|
||||
# NOTE: Normalization layer took dict in input and outputs a dict that why
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_sac_processor(
    config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
    """Build the pre- and post-processing pipelines used with the SAC policy.

    Preprocessing: empty rename, normalization of input and output features,
    batching, then a move to ``config.device``. Postprocessing: move to CPU,
    then un-normalize the output features.

    Args:
        config: SAC configuration providing features, normalization mapping and device.
        dataset_stats: Statistics handed to the (un)normalizer steps; may be None.

    Returns:
        A ``(preprocessor, postprocessor)`` pair of ``RobotProcessor`` objects.
    """
    # Steps are created in the order in which they run.
    rename_step = RenameProcessor(rename_map={})
    normalize_step = NormalizerProcessor(
        features={**config.input_features, **config.output_features},
        norm_map=config.normalization_mapping,
        stats=dataset_stats,
    )
    batch_step = ToBatchProcessor()
    to_device_step = DeviceProcessor(device=config.device)

    to_cpu_step = DeviceProcessor(device="cpu")
    unnormalize_step = UnnormalizerProcessor(
        features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
    )

    preprocessor = RobotProcessor(
        steps=[rename_step, normalize_step, batch_step, to_device_step],
        name=PREPROCESSOR_DEFAULT_NAME,
    )
    postprocessor = RobotProcessor(
        steps=[to_cpu_step, unnormalize_step],
        name=POSTPROCESSOR_DEFAULT_NAME,
    )
    return preprocessor, postprocessor
|
||||
@@ -20,7 +20,6 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import OBS_IMAGE, REWARD
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
@@ -108,22 +107,12 @@ class Classifier(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize normalization (standardized with the policy framework)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
@@ -247,10 +236,6 @@ class Classifier(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Normalize inputs if needed
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
IdentityProcessor,
|
||||
NormalizerProcessor,
|
||||
RobotProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_classifier_processor(
    config: RewardClassifierConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
    """Build the pre- and post-processing pipelines for the reward classifier.

    Preprocessing normalizes the input features, then the output features (two
    separate normalizer steps), then moves data to ``config.device``.
    Postprocessing moves data to the CPU and applies an identity step.

    Args:
        config: Classifier configuration providing features, normalization
            mapping and device.
        dataset_stats: Statistics handed to both normalizer steps; may be None.

    Returns:
        A ``(preprocessor, postprocessor)`` pair of ``RobotProcessor`` objects.
    """
    normalize_inputs = NormalizerProcessor(
        features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
    )
    normalize_targets = NormalizerProcessor(
        features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
    )
    to_device = DeviceProcessor(device=config.device)

    to_cpu = DeviceProcessor(device="cpu")
    passthrough = IdentityProcessor()

    preprocessor = RobotProcessor(
        steps=[normalize_inputs, normalize_targets, to_device], name="classifier_preprocessor"
    )
    postprocessor = RobotProcessor(steps=[to_cpu, passthrough], name="classifier_postprocessor")
    return preprocessor, postprocessor
|
||||
@@ -53,21 +53,13 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import (
|
||||
Normalize,
|
||||
Unnormalize,
|
||||
)
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
@@ -76,102 +68,6 @@ from lerobot.policies.utils import (
|
||||
)
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
# Pattern for a dataset-variant marker: ".so" + digits, an optional "-suffix",
# immediately followed by "_buffer_".
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")


def canonicalise(k: str) -> str:
    """Return ``k`` with every dataset-variant marker collapsed to ``.buffer_``.

    For example, ``.so100-blue_buffer_`` and ``.so100_buffer_`` both become
    ``.buffer_``. Keys without such a marker are returned unchanged.
    """
    canonical_key = re.sub(_VARIANT_RE, ".buffer_", k)
    return canonical_key
|
||||
|
||||
|
||||
def standardise_state_dict(
    checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
    """Re-key ``checkpoint`` so its entries line up with the reference key set.

    Each key is canonicalised. If the canonical name is in ``ref_keys`` the entry
    is stored under it; when several keys collapse to the same canonical name
    only the first is kept and the rest are recorded as collisions. Keys whose
    canonical name is not in ``ref_keys`` are kept under their original name.

    Args:
        checkpoint: Mapping of checkpoint keys to tensors.
        ref_keys: Keys expected by the target model.
        verbose: When True, print collisions and the count of unmatched keys.

    Returns:
        The re-keyed dictionary and the list of original keys that were unmatched.
    """
    remapped: dict[str, torch.Tensor] = {}
    clashes: dict[str, list[str]] = {}
    leftovers: list[str] = []

    for original_key, tensor in checkpoint.items():
        canonical_key = canonicalise(original_key)
        if canonical_key not in ref_keys:
            leftovers.append(original_key)
            continue
        if canonical_key in remapped:
            # A previous variant already claimed this canonical name.
            clashes.setdefault(canonical_key, []).append(original_key)
            continue
        remapped[canonical_key] = tensor

    if verbose:
        for canonical_key, dropped in clashes.items():
            print(f"[standardise_state_dict] '{canonical_key}' ← {dropped}")
        if leftovers:
            print(f"[standardise_state_dict] kept {len(leftovers)} unmatched keys")

    # Unmatched entries are carried over under their original names.
    for original_key in leftovers:
        remapped[original_key] = checkpoint[original_key]
    return remapped, leftovers
|
||||
|
||||
|
||||
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
    """Return a copy of ``checkpoint`` with substrings of its keys renamed.

    Args:
        checkpoint (dict): The checkpoint dictionary. It is not modified.
        rename_str (str): Key mappings in the format "old1//new1,old2//new2".
            Empty segments (an empty string or a stray comma) are ignored, and
            each pair is split on its first "//" only.

    Returns:
        dict: A new dictionary in which every occurrence of each ``old``
        substring in a key is replaced by the matching ``new``. Mappings are
        applied in the order given, each on the result of the previous one.

    Raises:
        ValueError: If a non-empty pair has no "//" separator or an empty
            ``old`` part (an empty pattern would match between every character).
    """
    rename_dict: dict[str, str] = {}
    for pair in rename_str.split(","):
        if not pair:
            # Tolerate "" and trailing/duplicated commas instead of failing.
            continue
        old_key, separator, new_key = pair.partition("//")
        if not separator or not old_key:
            raise ValueError(f"Invalid rename pair {pair!r}; expected the form 'old//new'.")
        rename_dict[old_key] = new_key

    new_checkpoint = {}
    for key, value in checkpoint.items():
        renamed = key
        for old_key, new_key in rename_dict.items():
            renamed = renamed.replace(old_key, new_key)
        new_checkpoint[renamed] = value
    return new_checkpoint
|
||||
|
||||
|
||||
def load_smolvla(
    model: torch.nn.Module,
    filename: str | os.PathLike,
    *,
    device: str = "cpu",
    checkpoint_keys_mapping: str = "",
) -> torch.nn.Module:
    """Load a SmolVLA safetensors checkpoint into ``model``.

    The checkpoint is read from ``filename``, optionally re-keyed with
    ``checkpoint_keys_mapping``, standardised against the model's own state-dict
    keys, stripped of normalization entries, and loaded non-strictly.

    Args:
        model: Module receiving the weights; it is modified in place.
        filename: Path of the safetensors file.
        device: Device on which the tensors are loaded.
        checkpoint_keys_mapping: Optional renames in the form
            "old1//new1,old2//new2"; ignored unless it contains "//".

    Returns:
        The same ``model`` instance, with weights loaded.

    Raises:
        RuntimeError: If any unexpected key remains, or a missing key is not a
            normalization entry.
    """
    state_dict = safetensors.torch.load_file(filename, device=device)

    # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
    if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
        state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)

    state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict()))

    # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
    norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}

    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    # Missing normalization keys are expected because they were dropped above.
    if not all(key.startswith(norm_keys) for key in missing) or unexpected:
        # Exceptions do not interpolate printf-style arguments, so format the message here.
        raise RuntimeError(f"SmolVLA {len(missing)} missing / {len(unexpected)} unexpected keys")

    return model
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
@@ -326,28 +222,17 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.reset()
|
||||
|
||||
@@ -357,23 +242,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
|
||||
@classmethod
def _load_as_safetensor(
    cls,
    model: "SmolVLAPolicy",
    model_file: str,
    map_location: str,
    strict: bool,
):
    """Load ``model_file`` into ``model`` with the SmolVLA-specific key handling.

    The file is first loaded through ``safetensors.torch.load_model`` using the
    caller's ``strict`` setting, then loaded again through ``load_smolvla`` with
    the "model._orig_mod." prefix mapped to "model.". The result of the second
    load is returned.
    """
    safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
    key_mapping = "model._orig_mod.//model."
    loaded = load_smolvla(model, model_file, device=map_location, checkpoint_keys_mapping=key_mapping)
    return loaded
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@@ -389,7 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||
|
||||
@@ -397,8 +266,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -408,8 +275,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -450,11 +315,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
@@ -518,30 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
    """Tokenize the task text in ``batch`` and return ``(tokens, mask)``.

    A single task (a string or a one-element sequence) is repeated once per row
    of ``batch[OBS_STATE]``. A trailing newline is appended to each task that
    lacks one. Both tensors are placed on the device of ``batch[OBS_STATE]``,
    and the attention mask is cast to bool.
    """
    state = batch[OBS_STATE]
    target_device = state.device

    prompts = batch["task"]
    if isinstance(prompts, str):
        prompts = [prompts]
    if len(prompts) == 1:
        # Broadcast the single prompt over the batch dimension.
        prompts = [prompts[0]] * state.shape[0]

    prompts = [prompt if prompt.endswith("\n") else f"{prompt}\n" for prompt in prompts]

    encoded = self.language_tokenizer(
        prompts,
        padding=self.config.pad_language_to,
        padding_side="right",
        max_length=self.config.tokenizer_max_length,
        return_tensors="pt",
    )
    lang_tokens = encoded["input_ids"].to(device=target_device)
    lang_masks = encoded["attention_mask"].to(device=target_device, dtype=torch.bool)
    return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
def make_smolvla_processor(
    config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
    """Build the pre- and post-processing pipelines used with the SmolVLA policy.

    Preprocessing: empty rename, normalization of input and output features,
    batching, trailing-newline insertion on the task, tokenization with the
    tokenizer named by ``config.vlm_model_name``, then a move to
    ``config.device``. Postprocessing: move to CPU, then un-normalize the
    output features.

    Args:
        config: SmolVLA configuration providing features, normalization mapping,
            tokenizer settings and device.
        dataset_stats: Statistics handed to the (un)normalizer steps; may be None.

    Returns:
        A ``(preprocessor, postprocessor)`` pair of ``RobotProcessor`` objects.
    """
    # Steps are created in the order in which they run.
    rename_step = RenameProcessor(rename_map={})  # To mimic the same processor as pretrained one
    normalize_step = NormalizerProcessor(
        features={**config.input_features, **config.output_features},
        norm_map=config.normalization_mapping,
        stats=dataset_stats,
    )
    batch_step = ToBatchProcessor()
    newline_step = SmolVLANewLineProcessor()
    tokenize_step = TokenizerProcessor(
        tokenizer_name=config.vlm_model_name,
        padding=config.pad_language_to,
        padding_side="right",
        max_length=config.tokenizer_max_length,
    )
    to_device_step = DeviceProcessor(device=config.device)

    to_cpu_step = DeviceProcessor(device="cpu")
    unnormalize_step = UnnormalizerProcessor(
        features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
    )

    preprocessor = RobotProcessor(
        steps=[rename_step, normalize_step, batch_step, newline_step, tokenize_step, to_device_step],
        name=PREPROCESSOR_DEFAULT_NAME,
    )
    postprocessor = RobotProcessor(
        steps=[to_cpu_step, unnormalize_step],
        name=POSTPROCESSOR_DEFAULT_NAME,
    )
    return preprocessor, postprocessor
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ProcessorStep):
    """Processor step that makes the task text end with a newline character."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Append a newline to the task stored in the complementary data.

        The complementary-data dictionary is updated in place and the same
        ``transition`` is returned. Nothing changes when the complementary data
        or the "task" entry is absent, or when the task is neither a string nor
        a list of strings.
        """
        extras = transition.get(TransitionKey.COMPLEMENTARY_DATA)
        if extras is None or "task" not in extras:
            return transition

        prompt = extras["task"]

        if isinstance(prompt, str):
            # Single prompt: only rewrite it when the newline is missing.
            if not prompt.endswith("\n"):
                extras["task"] = f"{prompt}\n"
            return transition

        if isinstance(prompt, list) and all(isinstance(item, str) for item in prompt):
            # Batch of prompts: normalise each one independently.
            extras["task"] = [item if item.endswith("\n") else f"{item}\n" for item in prompt]

        # Any other value (including None) is left untouched.
        return transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return ``features`` unchanged; this step adds no features."""
        return features

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return an empty state dictionary; this step keeps no tensors."""
        empty_state: dict[str, torch.Tensor] = {}
        return empty_state

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Accept a state dictionary and ignore it (no-op)."""
        return None

    def reset(self) -> None:
        """Reset processor state (no-op)."""
        return None

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration, which is empty."""
        empty_config: dict[str, Any] = {}
        return empty_config
|
||||
@@ -36,7 +36,6 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
@@ -63,26 +62,19 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = TDMPCTOLD(config)
|
||||
self.model_target = deepcopy(self.model)
|
||||
for param in self.model_target.parameters():
|
||||
@@ -137,7 +129,6 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
actions = torch.clamp(actions, -1, +1)
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -147,11 +138,12 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -320,11 +312,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_tdmpc_processor(
    config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
    """Assemble the pre- and post-processing pipelines used with a TD-MPC policy.

    Args:
        config: Policy configuration; supplies the input/output features, the
            normalization mapping and the target device.
        dataset_stats: Statistics handed to the normalizer and unnormalizer steps.

    Returns:
        A ``(preprocessor, postprocessor)`` pair. The preprocessor runs, in order,
        a ``RenameProcessor`` with an empty map, a ``NormalizerProcessor`` over the
        input and output features, a ``ToBatchProcessor`` and a ``DeviceProcessor``
        targeting ``config.device``. The postprocessor runs a ``DeviceProcessor``
        targeting ``"cpu"`` followed by an ``UnnormalizerProcessor`` over the
        output features.
    """
    pre_steps = [
        RenameProcessor(rename_map={}),
        NormalizerProcessor(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        ToBatchProcessor(),
        DeviceProcessor(device=config.device),
    ]
    post_steps = [
        DeviceProcessor(device="cpu"),
        UnnormalizerProcessor(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
    ]

    preprocessor = RobotProcessor(steps=pre_steps, name=PREPROCESSOR_DEFAULT_NAME)
    postprocessor = RobotProcessor(steps=post_steps, name=POSTPROCESSOR_DEFAULT_NAME)
    return preprocessor, postprocessor
|
||||
@@ -28,7 +28,6 @@ import torchvision
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
@@ -48,7 +47,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: VQBeTConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -61,14 +59,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self.reset()
|
||||
@@ -128,7 +118,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -142,10 +131,12 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -165,10 +156,8 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
# loss: total loss of training RVQ
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
||||
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
|
||||
|
||||
def make_vqbet_processor(
    config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
    """Assemble the pre- and post-processing pipelines used with a VQ-BeT policy.

    Args:
        config: Policy configuration; supplies the input/output features, the
            normalization mapping and the target device.
        dataset_stats: Statistics handed to the normalizer and unnormalizer steps.

    Returns:
        A ``(preprocessor, postprocessor)`` pair. The preprocessor runs, in order,
        a ``RenameProcessor`` with an empty map, a ``NormalizerProcessor`` over the
        input and output features, a ``ToBatchProcessor`` and a ``DeviceProcessor``
        targeting ``config.device``. The postprocessor runs a ``DeviceProcessor``
        targeting ``"cpu"`` followed by an ``UnnormalizerProcessor`` over the
        output features.
    """
    pre_steps = [
        # Starts empty so that users can plug in their own key renaming.
        RenameProcessor(rename_map={}),
        NormalizerProcessor(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        ToBatchProcessor(),
        DeviceProcessor(device=config.device),
    ]
    post_steps = [
        DeviceProcessor(device="cpu"),
        UnnormalizerProcessor(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
    ]

    preprocessor = RobotProcessor(steps=pre_steps, name=PREPROCESSOR_DEFAULT_NAME)
    postprocessor = RobotProcessor(steps=post_steps, name=POSTPROCESSOR_DEFAULT_NAME)
    return preprocessor, postprocessor
|
||||
@@ -14,8 +14,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .batch_processor import ToBatchProcessor
|
||||
from .delta_action_processor import MapDeltaActionToRobotAction
|
||||
from .device_processor import DeviceProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryData,
|
||||
AddTeleopEventsAsInfo,
|
||||
GripperPenaltyProcessor,
|
||||
ImageCropResizeProcessor,
|
||||
InterventionActionProcessor,
|
||||
Numpy2TorchActionProcessor,
|
||||
RewardClassifierProcessor,
|
||||
TimeLimitProcessor,
|
||||
Torch2NumpyActionProcessor,
|
||||
)
|
||||
from .joint_observations_processor import JointVelocityProcessor, MotorCurrentProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessor
|
||||
from .pipeline import (
|
||||
ActionProcessor,
|
||||
@@ -32,22 +46,39 @@ from .pipeline import (
|
||||
TruncatedProcessor,
|
||||
)
|
||||
from .rename_processor import RenameProcessor
|
||||
from .tokenizer_processor import TokenizerProcessor
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessor",
|
||||
"AddTeleopActionAsComplimentaryData",
|
||||
"AddTeleopEventsAsInfo",
|
||||
"DeviceProcessor",
|
||||
"DoneProcessor",
|
||||
"MapDeltaActionToRobotAction",
|
||||
"EnvTransition",
|
||||
"GripperPenaltyProcessor",
|
||||
"IdentityProcessor",
|
||||
"ImageCropResizeProcessor",
|
||||
"InfoProcessor",
|
||||
"InterventionActionProcessor",
|
||||
"JointVelocityProcessor",
|
||||
"MapDeltaActionToRobotAction",
|
||||
"MotorCurrentProcessor",
|
||||
"NormalizerProcessor",
|
||||
"UnnormalizerProcessor",
|
||||
"hotswap_stats",
|
||||
"ObservationProcessor",
|
||||
"ProcessorStep",
|
||||
"ProcessorStepRegistry",
|
||||
"RenameProcessor",
|
||||
"RewardClassifierProcessor",
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"ToBatchProcessor",
|
||||
"TokenizerProcessor",
|
||||
"TimeLimitProcessor",
|
||||
"Numpy2TorchActionProcessor",
|
||||
"Torch2NumpyActionProcessor",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class ToBatchProcessor:
    """Add a leading batch dimension to unbatched entries of a transition.

    The transition is edited in place and then returned. Only values that look
    unbatched are touched; everything else is left exactly as it was:

    - State observations (``OBS_STATE``, ``OBS_ENV_STATE``): a 1-D tensor is
      unsqueezed at dim 0.
    - Image observations (``OBS_IMAGE`` and every key starting with
      ``OBS_IMAGES`` followed by a dot): a 3-D tensor is unsqueezed at dim 0.
    - Action: a 1-D tensor is unsqueezed at dim 0.
    - Complementary data: a string ``"task"`` is wrapped in a one-element list,
      and 0-D ``"index"`` / ``"task_index"`` tensors are unsqueezed at dim 0.

    Example:
        ```python
        # State: (7,) -> (1, 7)
        # Image: (224, 224, 3) -> (1, 224, 224, 3)
        # Action: (4,) -> (1, 4)
        # Task: "pick_cube" -> ["pick_cube"]
        # Already batched: (1, 7) -> (1, 7) [unchanged]
        ```
    """

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Batch observation, action and complementary data, then return ``transition``."""
        for step in (
            self._process_observation,
            self._process_action,
            self._process_complementary_data,
        ):
            step(transition)
        return transition

    @staticmethod
    def _is_unbatched(value: Any, unbatched_ndim: int) -> bool:
        """Tell whether ``value`` is a tensor with exactly ``unbatched_ndim`` dimensions."""
        return isinstance(value, Tensor) and value.dim() == unbatched_ndim

    def _process_observation(self, transition: EnvTransition) -> None:
        """Unsqueeze 1-D state tensors and 3-D image tensors of the observation in place."""
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return

        # State vectors are considered unbatched when they are 1-D.
        for state_key in (OBS_STATE, OBS_ENV_STATE):
            if state_key in observation and self._is_unbatched(observation[state_key], 1):
                observation[state_key] = observation[state_key].unsqueeze(0)

        # The single-image key is considered unbatched when it is 3-D.
        if OBS_IMAGE in observation and self._is_unbatched(observation[OBS_IMAGE], 3):
            observation[OBS_IMAGE] = observation[OBS_IMAGE].unsqueeze(0)

        # Same rule for every per-camera key under the images prefix.
        camera_prefix = f"{OBS_IMAGES}."
        unbatched_cameras = [
            name
            for name, frame in observation.items()
            if name.startswith(camera_prefix) and self._is_unbatched(frame, 3)
        ]
        for name in unbatched_cameras:
            observation[name] = observation[name].unsqueeze(0)

    def _process_action(self, transition: EnvTransition) -> None:
        """Unsqueeze a 1-D action tensor in place; any other action is left untouched."""
        action = transition.get(TransitionKey.ACTION)
        if self._is_unbatched(action, 1):
            transition[TransitionKey.ACTION] = action.unsqueeze(0)

    def _process_complementary_data(self, transition: EnvTransition) -> None:
        """Batch the ``task``, ``index`` and ``task_index`` entries of the complementary data in place."""
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
        if complementary_data is None:
            return

        # A bare string task becomes a batch of one.
        if "task" in complementary_data and isinstance(complementary_data["task"], str):
            complementary_data["task"] = [complementary_data["task"]]

        # Scalar (0-D) index tensors get a batch dimension.
        for name in ("index", "task_index"):
            if name in complementary_data and self._is_unbatched(complementary_data[name], 0):
                complementary_data[name] = complementary_data[name].unsqueeze(0)

    def get_config(self) -> dict[str, Any]:
        """Return an empty mapping: this step has no serializable options."""
        return {}

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return an empty mapping: this step holds no tensors."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Accept ``state`` and ignore it: there is nothing to restore."""

    def reset(self) -> None:
        """Do nothing: this step keeps no state between calls."""

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return ``features`` as received; this step does not alter them."""
        return features
|
||||
@@ -0,0 +1,225 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from .pipeline import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if isinstance(x, np.ndarray):
|
||||
# Keep images (uint8 HWC) and python objects as-is
|
||||
if x.dtype == np.uint8 or x.dtype == np.object_:
|
||||
return x
|
||||
# Scalars/arrays to float32 tensor
|
||||
return torch.as_tensor(x, dtype=torch.float32)
|
||||
# Anything else to float32 tensor
|
||||
return torch.as_tensor(x, dtype=torch.float32)
|
||||
|
||||
|
||||
def _from_tensor(x: Any):
    """Turn a tensor into a plain Python or NumPy value; pass anything else through.

    A tensor holding exactly one element is returned as a Python scalar via
    ``item()``; any other tensor is detached, moved to the CPU and returned as a
    NumPy array. Non-tensor inputs are returned unchanged.
    """
    if not isinstance(x, torch.Tensor):
        return x
    if x.numel() == 1:
        return x.item()
    return x.detach().cpu().numpy()
|
||||
|
||||
|
||||
def _is_image(arr: Any) -> bool:
    """Return True only for a 3-dimensional NumPy array of dtype ``uint8``."""
    if not isinstance(arr, np.ndarray):
        return False
    return arr.ndim == 3 and arr.dtype == np.uint8
|
||||
|
||||
|
||||
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Partition ``obs`` into ``(state, images)`` using ``_is_image`` as the criterion.

    Entries recognized as images go to the second dict, all others to the first;
    keys and values are carried over unchanged and keep their relative order.
    """
    images = {name: value for name, value in obs.items() if _is_image(value)}
    state = {name: value for name, value in obs.items() if name not in images}
    return state, images
|
||||
|
||||
|
||||
def make_obs_act_transition(
    *, obs: dict[str, Any] | None = None, act: dict[str, Any] | None = None
) -> EnvTransition:
    """Build a transition that carries only an observation and/or an action.

    Args:
        obs: Observation dict stored as-is (not copied); an empty dict when ``None``.
        act: Action dict stored as-is (not copied); an empty dict when ``None``.

    Returns:
        A transition whose info and complementary data are empty dicts and whose
        reward, done and truncated entries are ``None``.
    """
    transition: EnvTransition = {
        TransitionKey.OBSERVATION: obs if obs is not None else {},
        TransitionKey.ACTION: act if act is not None else {},
        TransitionKey.INFO: {},
        TransitionKey.COMPLEMENTARY_DATA: {},
    }
    # The remaining entries are present but unset.
    transition.update(
        dict.fromkeys((TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED))
    )
    return transition
|
||||
|
||||
|
||||
def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
    """
    Wrap a raw teleop action dict into an EnvTransition, under the ACTION key.

    Every key is prefixed with ``"action."``. ``Rotation`` and ``dict`` values are
    stored unchanged; all other values go through ``_to_tensor`` (NumPy scalars and
    Python scalars are first wrapped with ``np.array``).
    """
    passthrough_types = (Rotation, dict)
    converted: dict[str, Any] = {}
    for name, value in action.items():
        target_key = f"action.{name}"
        if isinstance(value, passthrough_types):
            # These types must not be turned into tensors.
            converted[target_key] = value
        else:
            converted[target_key] = _to_tensor(np.array(value) if np.isscalar(value) else value)
    return make_obs_act_transition(act=converted)
|
||||
|
||||
|
||||
# TODO(Adil, Pepijn): Over time we can maybe add these converters to pipeline.py itself
|
||||
def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransition:
    """
    Wrap a raw robot observation dict into an EnvTransition, under the OBSERVATION key.

    Entries classified as images by ``_split_obs_to_state_and_images`` are stored
    unchanged under ``"observation.images.<name>"``; all other entries go through
    ``_to_tensor`` and are stored under ``"observation.state.<name>"``.
    """
    state_part, image_part = _split_obs_to_state_and_images(observation)

    flat_obs: dict[str, Any] = {
        f"observation.state.{name}": _to_tensor(np.array(value) if np.isscalar(value) else value)
        for name, value in state_part.items()
    }
    flat_obs.update({f"observation.images.{cam}": frame for cam, frame in image_part.items()})

    return make_obs_act_transition(obs=flat_obs)
|
||||
|
||||
|
||||
def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
    """
    Extract raw robot commands from the ACTION entry of an EnvTransition.

    Only string keys that start with ``"action."`` and end with ``".pos"`` or
    ``".vel"`` are kept; the ``"action."`` prefix is stripped and each value is
    converted with ``float``. All other entries are ignored.

    Args:
        transition: Transition whose ACTION entry is a dict (a missing or empty
            entry yields an empty result).

    Returns:
        A dict mapping the stripped key (e.g. ``"<name>.pos"``) to a Python float.
    """
    action_dict = transition.get(TransitionKey.ACTION) or {}
    return {
        key.removeprefix("action."): float(value)
        for key, value in action_dict.items()
        if isinstance(key, str) and key.startswith("action.") and key.endswith((".pos", ".vel"))
    }
|
||||
|
||||
|
||||
def to_dataset_frame(
    transitions_or_transition: EnvTransition | Iterable[EnvTransition], features: dict[str, dict]
) -> dict[str, Any]:
    """
    Converts a single EnvTransition or an iterable of them into a flat,
    dataset-friendly dictionary for training or evaluation, according to
    the provided `features` spec.

    Args:
        transitions_or_transition: Either a single EnvTransition dict
            or an iterable of them (which will be merged; later transitions
            override earlier ones on conflicting keys).
        features (dict[str, dict]):
            A feature specification dictionary:
              - 'action': dict with 'names': list of action feature names
              - 'observation.state': dict with 'names': list of state feature names
              - keys starting with 'observation.images.' are passed through

    Returns:
        batch (dict[str, Any]): Flat dictionary containing:
          - float32 numpy arrays for "observation.state" and "action"
            (a name missing from the transition contributes 0.0)
          - any images listed in `features` that are present in the observation
          - next.{reward,done,truncated} when they are not None
          - *_is_pad flags and task from complementary_data

    Raises:
        TypeError: If the input is neither a transition dict nor an iterable.
    """
    action_names = features.get("action", {}).get("names", [])
    obs_state_names = features.get("observation.state", {}).get("names", [])
    image_keys = [k for k in features if k.startswith("observation.images.")]

    def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
        # Work on deep copies so that neither input transition is modified.
        out = deepcopy(base)
        for key in (
            TransitionKey.OBSERVATION,
            TransitionKey.ACTION,
            TransitionKey.INFO,
            TransitionKey.COMPLEMENTARY_DATA,
        ):
            if other.get(key):
                out.setdefault(key, {}).update(deepcopy(other[key]))
        # Scalar-like entries are simply overridden by the later transition.
        for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
            if k in other:
                out[k] = other[k]
        return out

    def _ensure_transition(obj) -> EnvTransition:
        # A dict with at least one TransitionKey key is a single transition.
        if isinstance(obj, dict) and any(isinstance(k, TransitionKey) for k in obj):
            return obj
        # Otherwise fold an iterable of transitions into one, left to right.
        if isinstance(obj, Iterable):
            items = list(obj)
            if not items:
                return {}
            acc = items[0]
            for t in items[1:]:
                acc = _merge(acc, t)
            return acc
        raise TypeError("Expected EnvTransition or iterable of them")

    tr = _ensure_transition(transitions_or_transition)
    obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
    act = tr.get(TransitionKey.ACTION, {}) or {}
    batch: dict[str, Any] = {}

    # Images passthrough
    for k in image_keys:
        if k in obs:
            batch[k] = obs[k]

    # Observation.state vector
    if obs_state_names:
        vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in obs_state_names]
        batch["observation.state"] = np.asarray(vals, dtype=np.float32)

    # Action vector
    if action_names:
        vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names]
        batch["action"] = np.asarray(vals, dtype=np.float32)

    # Next.* fields
    if tr.get(TransitionKey.REWARD) is not None:
        batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD])
    if tr.get(TransitionKey.DONE) is not None:
        batch["next.done"] = _from_tensor(tr[TransitionKey.DONE])
    if tr.get(TransitionKey.TRUNCATED) is not None:
        batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED])

    # Complementary data flags and task
    comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
    if comp:
        # pad flags
        for k, v in comp.items():
            if k.endswith("_is_pad"):
                batch[k] = v
        # task label
        if comp.get("task") is not None:
            batch["task"] = comp["task"]

    return batch
|
||||
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
||||
@dataclass
|
||||
class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
"""
|
||||
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
|
||||
for use with inverse kinematics processors.
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.delta_x": float,
|
||||
"action.delta_y": float,
|
||||
"action.delta_z": float,
|
||||
"action.gripper": float (optional),
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.target_x": float,
|
||||
"action.target_y": float,
|
||||
"action.target_z": float,
|
||||
"action.target_wx": float,
|
||||
"action.target_wy": float,
|
||||
"action.target_wz": float,
|
||||
"action.gripper": float,
|
||||
}
|
||||
"""
|
||||
|
||||
# Scale factors for delta movements
|
||||
position_scale: float = 1.0
|
||||
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
|
||||
gripper_deadzone: float = 0.1 # Threshold for gripper activation
|
||||
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def action(self, action: dict | Tensor | None) -> dict:
|
||||
if action is None:
|
||||
return {}
|
||||
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||
if isinstance(action, dict):
|
||||
delta_x = action.pop("action.delta_x", 0.0)
|
||||
delta_y = action.pop("action.delta_y", 0.0)
|
||||
delta_z = action.pop("action.delta_z", 0.0)
|
||||
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
|
||||
else:
|
||||
delta_x = action[0].item()
|
||||
delta_y = action[1].item()
|
||||
delta_z = action[2].item()
|
||||
gripper = action[3].item()
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
position_magnitude = abs(delta_x) + abs(delta_y) + abs(delta_z)
|
||||
enabled = position_magnitude > 1e-6 # Small threshold to avoid noise
|
||||
|
||||
# Scale the deltas appropriately
|
||||
scaled_delta_x = float(delta_x) * self.position_scale
|
||||
scaled_delta_y = float(delta_y) * self.position_scale
|
||||
scaled_delta_z = float(delta_z) * self.position_scale
|
||||
|
||||
# For gamepad/keyboard, we don't have rotation input, so set to 0
|
||||
# These could be extended in the future for more sophisticated teleoperators
|
||||
target_wx = 0.0
|
||||
target_wy = 0.0
|
||||
target_wz = 0.0
|
||||
|
||||
# Update action with robot target format
|
||||
action = {
|
||||
"action.enabled": enabled,
|
||||
"action.target_x": scaled_delta_x,
|
||||
"action.target_y": scaled_delta_y,
|
||||
"action.target_z": scaled_delta_z,
|
||||
"action.target_wx": target_wx,
|
||||
"action.target_wy": target_wy,
|
||||
"action.target_wz": target_wz,
|
||||
"action.gripper": float(gripper),
|
||||
}
|
||||
|
||||
self._prev_enabled = enabled
|
||||
return action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Register the output action entries in ``features`` and return it.

    Each emitted action key is declared as an ``ACTION`` feature of shape
    ``(1,)``. The mapping is updated in place.
    """
    emitted_keys = (
        "action.enabled",
        "action.target_x",
        "action.target_y",
        "action.target_z",
        "action.target_wx",
        "action.target_wy",
        "action.target_wz",
        "action.gripper",
    )
    for feature_key in emitted_keys:
        features[feature_key] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
    return features
|
||||
|
||||
def reset(self):
    """Clear the remembered "enabled" flag from the previous action."""
    setattr(self, "_prev_enabled", False)
|
||||
@@ -19,24 +19,80 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("device_processor")
|
||||
@dataclass
|
||||
class DeviceProcessor:
|
||||
"""Processes transitions by moving tensors to the specified device.
|
||||
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
|
||||
|
||||
This processor ensures that all tensors in the transition are moved to the
|
||||
specified device (CPU or GPU) before they are returned.
|
||||
specified device (CPU or GPU) before they are returned. It can also convert
|
||||
floating-point tensors to a specified dtype while preserving non-float types
|
||||
(int, long, bool, etc.).
|
||||
"""
|
||||
|
||||
device: torch.device = "cpu"
|
||||
device: str = "cpu"
|
||||
float_dtype: str | None = None
|
||||
_device: torch.device | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.device = get_safe_torch_device(self.device)
|
||||
self._device = get_safe_torch_device(self.device)
|
||||
self.device = self._device.type
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
# Validate and convert float_dtype string to torch dtype
|
||||
if self.float_dtype is not None:
|
||||
dtype_mapping = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"half": torch.float16,
|
||||
"float": torch.float32,
|
||||
"double": torch.float64,
|
||||
}
|
||||
|
||||
if self.float_dtype not in dtype_mapping:
|
||||
available_dtypes = list(dtype_mapping.keys())
|
||||
raise ValueError(
|
||||
f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}"
|
||||
)
|
||||
|
||||
self._target_float_dtype = dtype_mapping[self.float_dtype]
|
||||
else:
|
||||
self._target_float_dtype = None
|
||||
|
||||
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Process a tensor by moving to device and optionally converting float dtype.
|
||||
|
||||
If the tensor is already on a GPU and we're configured for a GPU, it preserves
|
||||
that GPU placement (useful for multi-GPU training with Accelerate).
|
||||
Otherwise, it moves to the configured device.
|
||||
"""
|
||||
# Determine target device
|
||||
if tensor.is_cuda and self._device.type == "cuda":
|
||||
# Both tensor and target are on GPU - preserve tensor's GPU placement
|
||||
# This handles multi-GPU scenarios where Accelerate has already placed
|
||||
# tensors on the correct GPU for each process
|
||||
target_device = tensor.device
|
||||
else:
|
||||
# Either tensor is on CPU, or we're configured for CPU
|
||||
# In both cases, use the configured device
|
||||
target_device = self._device
|
||||
|
||||
# Only move if necessary
|
||||
if tensor.device != target_device:
|
||||
tensor = tensor.to(target_device, non_blocking=self.non_blocking)
|
||||
|
||||
# Convert float dtype if specified and tensor is floating point
|
||||
if self._target_float_dtype is not None and tensor.is_floating_point():
|
||||
tensor = tensor.to(dtype=self._target_float_dtype)
|
||||
|
||||
return tensor
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Create a copy of the transition
|
||||
new_transition = transition.copy()
|
||||
@@ -45,7 +101,7 @@ class DeviceProcessor:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_observation = {
|
||||
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
|
||||
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in observation.items()
|
||||
}
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
@@ -53,30 +109,54 @@ class DeviceProcessor:
|
||||
# Process action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
|
||||
new_transition[TransitionKey.ACTION] = self._process_tensor(action)
|
||||
|
||||
# Process reward tensor
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
if reward is not None and isinstance(reward, torch.Tensor):
|
||||
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
|
||||
new_transition[TransitionKey.REWARD] = self._process_tensor(reward)
|
||||
|
||||
# Process done tensor
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
if done is not None and isinstance(done, torch.Tensor):
|
||||
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
|
||||
new_transition[TransitionKey.DONE] = self._process_tensor(done)
|
||||
|
||||
# Process truncated tensor
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is not None and isinstance(truncated, torch.Tensor):
|
||||
new_transition[TransitionKey.TRUNCATED] = truncated.to(
|
||||
self.device, non_blocking=self.non_blocking
|
||||
)
|
||||
new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated)
|
||||
|
||||
# Process complementary data tensors
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is not None:
|
||||
new_complementary_data = {}
|
||||
|
||||
# Process all items in complementary_data
|
||||
for key, value in complementary_data.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
new_complementary_data[key] = self._process_tensor(value)
|
||||
else:
|
||||
new_complementary_data[key] = value
|
||||
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {"device": self.device}
|
||||
return {"device": self.device, "float_dtype": self.float_dtype}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
    """Return an empty mapping: this processor has no tensor state."""
    return dict()
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
    """Accept a state dictionary and ignore it (nothing to restore)."""
    return None
|
||||
|
||||
def reset(self) -> None:
    """Do nothing: there is no per-episode state to clear."""
    return None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Return ``features`` unchanged; this step adds or alters no feature."""
    unchanged = features
    return unchanged
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
# Name used for the gripper entry when looking up joint positions and
# building gripper action keys in the processors of this module.
GRIPPER_KEY = "gripper"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@dataclass
class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
    """Complementary-data step that records the teleoperator's current action.

    The value returned by ``teleop_device.get_action()`` is stored under the
    ``"teleop_action"`` key.
    """

    teleop_device: Teleoperator

    def complementary_data(self, complementary_data: dict | None) -> dict:
        """Return a shallow copy of the input extended with ``teleop_action``.

        ``None`` is treated as an empty mapping; the input is not modified.
        """
        enriched = dict(complementary_data) if complementary_data is not None else {}
        enriched["teleop_action"] = self.teleop_device.get_action()
        return enriched
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
class AddTeleopEventsAsInfo(InfoProcessor):
    """Info step that merges the teleoperator's control events into ``info``.

    Devices without a ``get_teleop_events`` attribute contribute nothing.
    """

    teleop_device: Teleoperator

    def info(self, info: dict | None) -> dict:
        """Return a shallow copy of ``info`` updated with the teleop events.

        ``None`` is treated as an empty mapping; the input is not modified.
        """
        merged = dict(info) if info is not None else {}
        if hasattr(self.teleop_device, "get_teleop_events"):
            merged.update(self.teleop_device.get_teleop_events())
        else:
            merged.update({})
        return merged
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("torch2numpy_action_processor")
@dataclass
class Torch2NumpyActionProcessor(ActionProcessor):
    """Action step turning a PyTorch tensor into a NumPy array.

    With ``squeeze_batch_dim`` enabled, a leading dimension of size one is
    removed from arrays that have more than one dimension.
    """

    squeeze_batch_dim: bool = True

    def action(self, action: torch.Tensor | None) -> np.ndarray | None:
        """Return ``action`` as a detached CPU NumPy array (``None`` passes through).

        Raises:
            TypeError: if ``action`` is neither ``None`` nor a ``torch.Tensor``.
        """
        if action is None:
            return None

        if not isinstance(action, torch.Tensor):
            raise TypeError(
                f"Expected torch.Tensor or None, got {type(action).__name__}. "
                "Use appropriate processor for non-tensor actions."
            )

        as_array = action.detach().cpu().numpy()

        # Only a leading size-1 axis of a multi-dimensional array is treated
        # as a batch axis; 0-d and 1-d arrays are returned as they are.
        has_unit_batch_axis = as_array.ndim > 1 and as_array.shape[0] == 1
        if self.squeeze_batch_dim and has_unit_batch_axis:
            as_array = as_array.squeeze(0)

        return as_array
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass
class Numpy2TorchActionProcessor(ActionProcessor):
    """Action step turning a NumPy array into a PyTorch tensor."""

    def action(self, action: np.ndarray | None) -> torch.Tensor | None:
        """Wrap ``action`` with ``torch.from_numpy`` (``None`` passes through).

        Raises:
            TypeError: if ``action`` is neither ``None`` nor a ``np.ndarray``.
        """
        if action is None:
            return None

        if isinstance(action, np.ndarray):
            return torch.from_numpy(action)

        raise TypeError(
            f"Expected np.ndarray or None, got {type(action).__name__}. "
            "Use appropriate processor for non-tensor actions."
        )
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
@dataclass
class ImageCropResizeProcessor(ObservationProcessor):
    """Crop and/or resize every image entry of an observation.

    An entry is treated as an image when its key contains ``"image"``.

    Attributes:
        crop_params_dict: Optional mapping from observation key to the four
            integers passed positionally to ``torchvision``'s ``crop``.
        resize_size: Optional two-integer target size passed to ``resize``.
    """

    crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
    resize_size: tuple[int, int] | None = None

    def observation(self, observation: dict | None) -> dict | None:
        """Return a copy of ``observation`` with image entries cropped/resized.

        ``None`` and observations of an unconfigured processor are returned
        unchanged. The input dict itself is never modified.
        """
        if observation is None:
            return None

        if self.resize_size is None and not self.crop_params_dict:
            return observation

        new_observation = dict(observation)

        for key, image in observation.items():
            if "image" not in key:
                continue

            device = image.device
            # NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu
            if device.type == "mps":
                image = image.cpu()

            if self.crop_params_dict is not None and key in self.crop_params_dict:
                image = F.crop(image, *self.crop_params_dict[key])

            if self.resize_size is not None:
                image = F.resize(image, self.resize_size)
                # NOTE(review): clamping presumably undoes interpolation overshoot
                # for images scaled to [0, 1] — confirm the expected value range.
                image = image.clamp(0.0, 1.0)

            # Put the result back on the device the image came from.
            new_observation[key] = image.to(device)

        return new_observation

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {
            "crop_params_dict": self.crop_params_dict,
            "resize_size": self.resize_size,
        }

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Update the declared shape of image features after resizing.

        Only the trailing two (spatial) dimensions are replaced by
        ``resize_size``; leading dimensions such as the channel axis are kept,
        matching ``resize``, which operates on ``[..., H, W]`` tensors.
        Previously the whole shape was overwritten with ``resize_size``, which
        dropped the channel dimension. ``features`` is updated in place.
        """
        if self.resize_size is None:
            return features

        for key, feature in features.items():
            if "image" in key:
                resized_shape = (*feature.shape[:-2], *self.resize_size)
                features[key] = PolicyFeature(type=feature.type, shape=resized_shape)
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor:
    """Count processed steps and mark the episode truncated at the step limit.

    Attributes:
        max_episode_steps: Step count at which truncation is forced.
        current_step: Number of transitions counted since the last reset.
    """

    max_episode_steps: int
    current_step: int = 0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Advance the step counter and set ``TRUNCATED`` once the limit is hit.

        Transitions whose ``TRUNCATED`` entry is ``None`` are returned as they
        are and do not advance the counter.
        """
        incoming_flag = transition.get(TransitionKey.TRUNCATED)
        if incoming_flag is None:
            return transition

        self.current_step += 1
        limit_reached = self.current_step >= self.max_episode_steps

        updated = transition.copy()
        updated[TransitionKey.TRUNCATED] = True if limit_reached else incoming_flag
        return updated

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return dict(max_episode_steps=self.max_episode_steps)

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return an empty mapping; no tensor state is kept."""
        return dict()

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Ignore ``state``; there is nothing to restore."""
        return None

    def reset(self) -> None:
        """Restart the step counter at zero."""
        self.current_step = 0

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return ``features`` unchanged."""
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor:
    """Compute a penalty for gripper commands that oppose the gripper state.

    The penalty is written to the ``"discrete_penalty"`` entry of the
    transition's complementary data.

    Attributes:
        penalty: Value stored when the penalty condition holds (0 otherwise).
        max_gripper_pos: Divisor used to normalize gripper state and action.
    """

    penalty: float = -0.01
    max_gripper_pos: float = 30.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` carrying the gripper penalty.

        The transition is returned unchanged when the action, the
        complementary data, its ``"raw_joint_positions"`` entry or the gripper
        position inside it is missing. The input dicts are never modified.
        """
        action = transition.get(TransitionKey.ACTION)
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)

        if complementary_data is None or action is None:
            return transition

        # Guard each lookup level: a missing "raw_joint_positions" entry used
        # to raise AttributeError on the chained .get().
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        if raw_joint_positions is None:
            return transition

        current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY)
        if current_gripper_pos is None:
            return transition

        gripper_action = action[f"action.{GRIPPER_KEY}.pos"]

        # Normalize gripper action and state by the same maximum position.
        gripper_action_normalized = gripper_action / self.max_gripper_pos
        gripper_state_normalized = current_gripper_pos / self.max_gripper_pos

        # Penalize a command above 0.5 while the state is below 0.5, or a
        # command below 0.5 while the state is above 0.75.
        gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
            gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
        )
        gripper_penalty = self.penalty * int(gripper_penalty_bool)

        # Build a fresh dict: transition.copy() is shallow, so updating the
        # existing complementary data would also alter the input transition.
        new_complementary_data = dict(complementary_data)
        new_complementary_data["discrete_penalty"] = gripper_penalty

        new_transition = transition.copy()
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data  # type: ignore[misc]
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return {
            "penalty": self.penalty,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return an empty mapping; no tensor state is kept."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Ignore ``state``; there is nothing to restore."""

    def reset(self) -> None:
        """Reset the processor state."""
        # Kept for backward compatibility; nothing else in this class reads
        # this attribute.
        self.last_gripper_state = None

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return ``features`` unchanged."""
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor:
    """Override the action during human intervention and handle termination.

    Attributes:
        use_gripper: Append the gripper command when converting a teleop
            action dict to a tensor.
        terminate_on_success: Mark the episode done when the success event is set.
    """

    use_gripper: bool = False
    terminate_on_success: bool = True

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` with intervention handling applied.

        Transitions without an action are returned unchanged. Otherwise:

        * when the intervention event is set and a teleop action is available,
          the action is replaced by the teleop action as a tensor with the
          original action's dtype and device;
        * ``DONE`` and ``REWARD`` are set from the terminate/success events;
        * the intervention, re-record and success events are written to info;
        * the resulting action is stored as ``"teleop_action"`` in the
          complementary data.

        Info and complementary data are copied, so the input is not modified.
        """
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition

        # Work on copies ("or {}" also covers entries explicitly set to None):
        # transition.copy() is shallow, so in-place writes would leak into the input.
        info = dict(transition.get(TransitionKey.INFO) or {})
        complementary_data = dict(transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})

        # The teleop action is stored in complementary data by the step that
        # adds it; info is kept as a fallback for backward compatibility.
        teleop_action = complementary_data.get("teleop_action")
        if teleop_action is None:
            teleop_action = info.get("teleop_action", {})

        is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
        terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
        success = info.get(TeleopEvents.SUCCESS, False)
        rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False)

        new_transition = transition.copy()

        # Override action if intervention is active
        if is_intervention and teleop_action is not None:
            if isinstance(teleop_action, dict):
                # Convert teleop_action dict to tensor format
                action_list = [
                    teleop_action.get("action.delta_x", 0.0),
                    teleop_action.get("action.delta_y", 0.0),
                    teleop_action.get("action.delta_z", 0.0),
                ]
                if self.use_gripper:
                    # Prefer the "action."-prefixed key used by the delta
                    # entries; fall back to the bare key read previously.
                    action_list.append(
                        teleop_action.get(f"action.{GRIPPER_KEY}", teleop_action.get(GRIPPER_KEY, 1.0))
                    )
            elif isinstance(teleop_action, np.ndarray):
                action_list = teleop_action.tolist()
            else:
                action_list = teleop_action

            new_transition[TransitionKey.ACTION] = torch.tensor(
                action_list, dtype=action.dtype, device=action.device
            )

        # Handle episode termination
        new_transition[TransitionKey.DONE] = bool(terminate_episode) or bool(
            self.terminate_on_success and success
        )
        new_transition[TransitionKey.REWARD] = float(success)

        # Update info with intervention metadata
        info[TeleopEvents.IS_INTERVENTION] = is_intervention
        info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode
        info[TeleopEvents.SUCCESS] = success
        new_transition[TransitionKey.INFO] = info

        # Record the action that was finally selected (policy or teleop).
        complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION)
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data

        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return {
            "use_gripper": self.use_gripper,
            "terminate_on_success": self.terminate_on_success,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return an empty mapping; no tensor state is kept."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Ignore ``state``; there is nothing to restore."""

    def reset(self) -> None:
        """Do nothing; no per-episode state is kept."""

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return ``features`` unchanged."""
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
class RewardClassifierProcessor:
    """Derive reward and termination from a reward classifier run on images.

    Attributes:
        pretrained_path: Location passed to ``Classifier.from_pretrained``;
            when ``None`` no classifier is loaded.
        device: Device the loaded classifier is moved to.
        success_threshold: Threshold forwarded to ``predict_reward``.
        success_reward: Reward assigned when the classifier reports success.
        terminate_on_success: Mark the episode done on success.
        reward_classifier: The classifier instance, or ``None``.
    """

    pretrained_path: str | None = None
    device: str = "cpu"
    success_threshold: float = 0.5
    success_reward: float = 1.0
    terminate_on_success: bool = True

    reward_classifier: Any = None

    def __post_init__(self):
        """Load the classifier when ``pretrained_path`` is given and put it in eval mode."""
        if self.pretrained_path is not None:
            from lerobot.policies.sac.reward_model.modeling_classifier import Classifier

            self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
            self.reward_classifier.to(self.device)
            self.reward_classifier.eval()

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` with classifier-based reward/done.

        The transition is returned unchanged when there is no observation, no
        classifier, or no observation key containing ``"image"``. The
        classifier call rate is stored in info under
        ``"reward_classifier_frequency"``; the input info dict is not modified.
        """
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None or self.reward_classifier is None:
            return transition

        # Extract images from observation
        images = {key: value for key, value in observation.items() if "image" in key}
        if not images:
            return transition

        # Run reward classifier
        start_time = time.perf_counter()
        with torch.inference_mode():
            success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold)
        elapsed = time.perf_counter() - start_time

        # Guard against a zero interval on coarse timers.
        classifier_frequency = 1 / elapsed if elapsed > 0 else float("inf")

        # Calculate reward and termination
        reward = transition.get(TransitionKey.REWARD, 0.0)
        terminated = transition.get(TransitionKey.DONE, False)

        if success == 1.0:
            reward = self.success_reward
            if self.terminate_on_success:
                terminated = True

        # Update transition
        new_transition = transition.copy()
        new_transition[TransitionKey.REWARD] = reward
        new_transition[TransitionKey.DONE] = terminated

        # Copy info before writing: transition.copy() is shallow, so writing
        # into the existing dict would also change the input transition.
        info = dict(transition.get(TransitionKey.INFO) or {})
        info["reward_classifier_frequency"] = classifier_frequency
        new_transition[TransitionKey.INFO] = info

        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return {
            "device": self.device,
            "success_threshold": self.success_threshold,
            "success_reward": self.success_reward,
            "terminate_on_success": self.terminate_on_success,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return an empty mapping; classifier weights are not exported here."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Ignore ``state``; there is nothing to restore."""

    def reset(self) -> None:
        """Do nothing; no per-episode state is kept."""

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return ``features`` unchanged."""
        return features
|
||||
@@ -0,0 +1,116 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor(ObservationProcessor):
    """Append finite-difference joint velocities to ``observation.state``.

    Inherits from ``ObservationProcessor`` so that the ``observation`` hook is
    actually invoked when the step is called on a transition; the class
    previously defined the hook without any ``__call__``.

    Attributes:
        joint_velocity_limits: Stored and serialized; not used in the computation.
        dt: Time step dividing the position difference.
        num_dof: Number of velocity entries added, used by ``transform_features``.
        last_joint_positions: Positions seen on the previous call, or ``None``.
    """

    joint_velocity_limits: float = 100.0
    dt: float = 1.0 / 10
    num_dof: int | None = None

    last_joint_positions: torch.Tensor | None = None

    def observation(self, observation: dict | None) -> dict | None:
        """Return a copy of ``observation`` whose state also holds velocities.

        ``None`` and observations without ``"observation.state"`` are returned
        unchanged. On the first call after a reset the velocities are zero,
        because the previous positions are initialized to the current ones.
        """
        if observation is None:
            return None

        # Get current joint positions (assuming they're in observation.state)
        current_positions = observation.get("observation.state")
        if current_positions is None:
            return observation

        # Initialize last joint positions if not already set
        if self.last_joint_positions is None:
            self.last_joint_positions = current_positions.clone()

        # Compute velocities
        joint_velocities = (current_positions - self.last_joint_positions) / self.dt
        self.last_joint_positions = current_positions.clone()

        # Extend the state with the velocities along the last dimension.
        new_observation = dict(observation)
        new_observation["observation.state"] = torch.cat([current_positions, joint_velocities], dim=-1)
        return new_observation

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step.

        ``num_dof`` is included so a step rebuilt from its config reports the
        same feature shape.
        """
        return {
            "joint_velocity_limits": self.joint_velocity_limits,
            "dt": self.dt,
            "num_dof": self.num_dof,
        }

    def reset(self) -> None:
        """Forget the previous joint positions."""
        self.last_joint_positions = None

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Grow the first dimension of ``observation.state`` by ``num_dof``.

        Nothing changes when the feature is absent or ``num_dof`` is ``None``.
        ``features`` is updated in place.
        """
        if "observation.state" in features and self.num_dof is not None:
            original_feature = features["observation.state"]
            # Positions plus one velocity per degree of freedom.
            new_shape = (original_feature.shape[0] + self.num_dof,) + original_feature.shape[1:]
            features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessor(ObservationProcessor):
    """Append motor current readings to ``observation.state``.

    Attributes:
        robot: Robot whose ``bus`` is read for ``"Present_Current"``; when
            ``None`` the step passes observations through.
    """

    robot: Robot | None = None

    def observation(self, observation: dict | None) -> dict | None:
        """Return a copy of ``observation`` whose state also holds motor currents.

        ``None`` is passed through; observations are returned unchanged when
        no robot is set or ``"observation.state"`` is missing. The bus is only
        read once a state is known to exist, so a missing state costs no read.
        """
        if observation is None:
            return None

        if self.robot is None:
            return observation

        current_state = observation.get("observation.state")
        if current_state is None:
            return observation

        # Get current values from robot state
        present_current_dict = self.robot.bus.sync_read("Present_Current")  # type: ignore[attr-defined]
        motor_currents = torch.tensor(
            [present_current_dict[name] for name in self.robot.bus.motors],  # type: ignore[attr-defined]
            dtype=torch.float32,
            # Build on the state's device so the concatenation below does not
            # fail for states that live off-CPU.
            device=current_state.device,
        ).unsqueeze(0)  # assumes the state has a leading batch axis of size 1 — TODO confirm

        new_observation = dict(observation)
        new_observation["observation.state"] = torch.cat([current_state, motor_currents], dim=-1)
        return new_observation

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Grow the first dimension of ``observation.state`` by the motor count.

        Nothing changes when the feature is absent, no robot is set, or the
        robot exposes no ``bus.motors``. ``features`` is updated in place.
        """
        if "observation.state" in features and self.robot is not None:
            original_feature = features["observation.state"]

            num_motors = 0
            if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"):  # type: ignore[attr-defined]
                num_motors = len(self.robot.bus.motors)  # type: ignore[attr-defined]

            if num_motors > 0:
                new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
                features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
        return features
|
||||
@@ -0,0 +1,502 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
|
||||
|
||||
This script:
|
||||
1. Loads an existing pretrained policy model
|
||||
2. Extracts normalization statistics from the model
|
||||
3. Creates both preprocessor and postprocessor:
|
||||
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
|
||||
- Postprocessor: unnormalizes outputs (actions) for inference
|
||||
4. Removes normalization layers from the model state_dict
|
||||
5. Saves the new model and both processors
|
||||
|
||||
Usage:
|
||||
python src/lerobot/processor/migrate_policy_normalization.py \
|
||||
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--policy-type act \
|
||||
--push-to-hub
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.batch_processor import ToBatchProcessor
|
||||
from lerobot.processor.device_processor import DeviceProcessor
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.processor.rename_processor import RenameProcessor
|
||||
|
||||
# Maps each supported ``--policy-type`` value to the dotted import path of its
# policy class.
POLICY_CLASSES = {
    "act": "lerobot.policies.act.modeling_act.ACTPolicy",
    "diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
    "pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
    "pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
    "smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
    "tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
    "vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
    "sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
    # The classifier is imported elsewhere in the code base from the SAC
    # reward-model module as ``Classifier``; the previous path
    # ("lerobot.policies.classifier...ClassifierPolicy") did not match it.
    "classifier": "lerobot.policies.sac.reward_model.modeling_classifier.Classifier",
}
|
||||
|
||||
|
||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    """Collect normalization statistics stored in a model ``state_dict``.

    A key is considered when it starts with one of the known normalization
    prefixes. The text after the prefix is split at its last dot: the tail is
    the statistic name (e.g. ``mean``) and the head, with underscores replaced
    by dots, is the feature name. Keys without a dot after the prefix are
    ignored.

    Returns:
        ``{feature_name: {stat_name: tensor_clone}}``.
    """
    # Checked in order; only the first matching prefix is applied to a key.
    known_prefixes = (
        "normalize_inputs.buffer_",
        "unnormalize_outputs.buffer_",
        "normalize_targets.buffer_",
        "normalize.",  # Must come after normalize_* patterns
        "unnormalize.",  # Must come after unnormalize_* patterns
        "input_normalizer.",
        "output_normalizer.",
    )

    collected: dict[str, dict[str, torch.Tensor]] = {}

    for key, tensor in state_dict.items():
        prefix = next((candidate for candidate in known_prefixes if key.startswith(candidate)), None)
        if prefix is None:
            continue

        # Need at least a feature name and a stat type separated by a dot.
        feature_part, separator, stat_type = key[len(prefix) :].rpartition(".")
        if not separator:
            continue

        feature_name = feature_part.replace("_", ".")
        collected.setdefault(feature_name, {})[stat_type] = tensor.clone()

    return collected
|
||||
|
||||
|
||||
def _infer_feature_type(name: str, *, from_stats: bool = False) -> FeatureType:
    """Guess a feature type from substrings of ``name``.

    Args:
        name: Feature key (or normalization-mapping key) to classify.
        from_stats: When True, use the wider keyword set applied to keys that
            come from extracted statistics ("pixels" counts as visual, "joint"
            and "position" count as state).

    Returns:
        VISUAL, STATE or ACTION; names that match nothing fall back to STATE.
    """
    visual_markers = ("image", "visual", "pixels") if from_stats else ("image", "visual")
    state_markers = ("state", "joint", "position") if from_stats else ("state",)

    if any(marker in name for marker in visual_markers):
        return FeatureType.VISUAL
    if any(marker in name for marker in state_markers):
        return FeatureType.STATE
    if "action" in name:
        return FeatureType.ACTION
    return FeatureType.STATE


def detect_features_and_norm_modes(
    config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
    """Detect features and normalization modes from config and stats.

    Normalization modes are taken from ``config["normalization_mapping"]`` when
    present. Mode names are matched case-insensitively against the members of
    ``NormalizationMode`` (so "mean_std", "MEAN_STD" and "IDENTITY" are all
    accepted). Mapping keys are first tried as ``FeatureType`` member names and
    otherwise classified by substring heuristics.

    Features are read from ``config["features"]`` when present, otherwise they
    are inferred from the shapes of the tensors in ``stats``. If the config
    yields no modes, they are inferred from which statistics each feature has.
    Finally VISUAL, STATE and ACTION receive defaults if still unset.

    Args:
        config: Policy configuration dictionary.
        stats: ``{feature_name: {stat_name: tensor}}`` statistics.

    Returns:
        A ``(features, norm_modes)`` tuple.
    """
    features = {}
    norm_modes = {}

    # First, check if there's a normalization_mapping in the config
    if "normalization_mapping" in config:
        print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
        for feature_name, mode_str in config["normalization_mapping"].items():
            # Look the mode up by enum member name, ignoring case, so that both
            # "mean_std" and "MEAN_STD" (and "IDENTITY") resolve.
            try:
                mode = NormalizationMode[str(mode_str).upper()]
            except KeyError:
                print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
                continue

            # A key that is itself a FeatureType member name (e.g. "VISUAL") is
            # used directly; anything else goes through the name heuristics.
            try:
                feature_type = FeatureType[str(feature_name).upper()]
            except KeyError:
                feature_type = _infer_feature_type(str(feature_name).lower())

            norm_modes[feature_type] = mode

    # Try to extract features from config
    if "features" in config:
        for key, feature_config in config["features"].items():
            shape = feature_config.get("shape", feature_config.get("dim"))
            shape = (shape,) if isinstance(shape, int) else tuple(shape)
            features[key] = PolicyFeature(_infer_feature_type(key), shape)

    # If no features in config, infer from stats
    if not features:
        for key, stat_dict in stats.items():
            # Any stat tensor of the feature is used as the source of its shape
            tensor = next(iter(stat_dict.values()))
            shape = tuple(tensor.shape)
            features[key] = PolicyFeature(_infer_feature_type(key, from_stats=True), shape)

    # If normalization modes weren't in config, determine based on available stats
    if not norm_modes:
        for key, stat_dict in stats.items():
            if key not in features:
                continue
            # The first feature seen for a given type decides that type's mode
            if "mean" in stat_dict and "std" in stat_dict:
                norm_modes.setdefault(features[key].type, NormalizationMode.MEAN_STD)
            elif "min" in stat_dict and "max" in stat_dict:
                norm_modes.setdefault(features[key].type, NormalizationMode.MIN_MAX)

    # Default normalization modes if not detected
    norm_modes.setdefault(FeatureType.VISUAL, NormalizationMode.MEAN_STD)
    norm_modes.setdefault(FeatureType.STATE, NormalizationMode.MIN_MAX)
    norm_modes.setdefault(FeatureType.ACTION, NormalizationMode.MEAN_STD)

    return features, norm_modes
|
||||
|
||||
|
||||
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return a copy of ``state_dict`` without normalization-layer entries.

    An entry is dropped when its key contains any of the fragments listed
    below, anywhere in the key. Kept entries preserve their original order and
    reference the same tensor objects; the input mapping is not modified.
    """
    normalization_fragments = (
        "normalize_inputs.",
        "unnormalize_outputs.",
        "normalize_targets.",
        "normalize.",
        "unnormalize.",
        "input_normalizer.",
        "output_normalizer.",
        "normalizer.",
    )

    return {
        name: value
        for name, value in state_dict.items()
        if not any(fragment in name for fragment in normalization_fragments)
    }
|
||||
|
||||
|
||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
    """Build ``PolicyFeature`` objects from dictionary-style feature descriptions.

    The feature type is guessed from the key alone: keys containing "image" or
    "visual" are VISUAL, keys containing "state" are STATE, keys containing
    "action" are ACTION, and everything else is STATE. The shape is read from
    the entry's "shape" field, falling back to "dim"; a bare integer becomes a
    one-element tuple.

    Args:
        features_dict: ``{feature_key: {"shape" | "dim": ...}}``.

    Returns:
        ``{feature_key: PolicyFeature(feature_type, shape)}``.
    """

    def _type_for(name):
        if "image" in name or "visual" in name:
            return FeatureType.VISUAL
        if "state" in name:
            return FeatureType.STATE
        if "action" in name:
            return FeatureType.ACTION
        return FeatureType.STATE

    result = {}
    for name, spec in features_dict.items():
        kind = _type_for(name)

        raw_shape = spec.get("shape", spec.get("dim"))
        if isinstance(raw_shape, int):
            dims = (raw_shape,)
        else:
            dims = tuple(raw_shape)

        result[name] = PolicyFeature(kind, dims)

    return result
|
||||
|
||||
|
||||
def load_model_from_hub(
    repo_id: str, revision: str = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
    """Load model state_dict and config from hub.

    Downloads ``model.safetensors``, ``config.json`` and ``train_config.json``
    from ``repo_id`` and parses them.

    Args:
        repo_id: Hub repository to download from.
        revision: Optional revision passed to every download; ``None`` leaves
            the choice to ``hf_hub_download``.

    Returns:
        ``(state_dict, config, train_config)``.
    """
    # Download files
    safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
    train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)

    # Load state_dict
    state_dict = load_safetensors(safetensors_path)

    # Load configs. Read explicitly as UTF-8 so parsing does not depend on the
    # platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    with open(train_config_path, encoding="utf-8") as f:
        train_config = json.load(f)

    return state_dict, config, train_config
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: migrate a policy checkpoint to the processor pipeline.

    Steps, in order:
      1. Parse arguments and resolve the hub destination (fails early when a
         local model is pushed without ``--hub-repo-id``).
      2. Load the state_dict, ``config.json`` and ``train_config.json`` from a
         local directory or from the hub.
      3. Extract normalization statistics and strip normalization entries from
         the state_dict.
      4. Instantiate the policy from the cleaned config and load the cleaned
         state_dict strictly.
      5. Build a preprocessor and a postprocessor carrying the statistics.
      6. Save (and optionally push) the processors, the model and a model card.

    Raises:
        ValueError: If the hub destination cannot be determined, the policy
            type is missing/unsupported, or mandatory feature sections are
            missing from the config.
    """
    parser = argparse.ArgumentParser(
        description="Migrate policy models with normalization layers to new pipeline system"
    )
    parser.add_argument(
        "--pretrained-path",
        type=str,
        required=True,
        help="Path to pretrained model (hub repo or local directory)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for migrated model (default: '<name>_migrated', derived from pretrained-path)",
    )
    parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub")
    parser.add_argument(
        "--hub-repo-id",
        type=str,
        default=None,
        help=(
            "Hub repository ID for pushing (default: '<pretrained-path>_migrated' for hub models; "
            "required for local models)"
        ),
    )
    parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
    parser.add_argument("--private", action="store_true", help="Make the hub repository private")

    args = parser.parse_args()

    is_local = os.path.isdir(args.pretrained_path)

    # Resolve the hub destination up front so that a missing --hub-repo-id is
    # reported before any download or model instantiation happens.
    if args.push_to_hub:
        if args.hub_repo_id:
            hub_repo_id = args.hub_repo_id
        elif not is_local:
            # Use same repo with "_migrated" suffix
            hub_repo_id = f"{args.pretrained_path}_migrated"
        else:
            raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
    else:
        hub_repo_id = None

    # Load model and config
    print(f"Loading model from {args.pretrained_path}...")
    if is_local:
        # Local directory
        state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
        with open(os.path.join(args.pretrained_path, "config.json"), encoding="utf-8") as f:
            config = json.load(f)
        with open(os.path.join(args.pretrained_path, "train_config.json"), encoding="utf-8") as f:
            train_config = json.load(f)
    else:
        # Hub repository
        state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)

    # Extract normalization statistics
    print("Extracting normalization statistics...")
    stats = extract_normalization_stats(state_dict)
    print(f"Found normalization statistics for: {list(stats.keys())}")

    # Detect input features and normalization modes
    print("Detecting features and normalization modes...")
    features, norm_map = detect_features_and_norm_modes(config, stats)
    print(f"Detected features: {list(features.keys())}")
    print(f"Normalization modes: {norm_map}")

    # Remove normalization layers from state_dict
    print("Removing normalization layers from model...")
    new_state_dict = remove_normalization_layers(state_dict)

    removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
    if removed_keys:
        print(f"Removed {len(removed_keys)} normalization layer keys")

    # Determine output path
    if args.output_dir:
        output_dir = Path(args.output_dir)
    elif is_local:
        output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated"
    else:
        output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Clean up config - work on a shallow copy so the loaded config stays intact
    cleaned_config = dict(config)
    if "normalization_mapping" in cleaned_config:
        print("Removing 'normalization_mapping' field from config")
        del cleaned_config["normalization_mapping"]

    # "type" selects the policy class and is not passed to the config constructor
    if "type" not in cleaned_config:
        raise ValueError("Missing mandatory 'type' in config")
    policy_type = cleaned_config.pop("type")
    if policy_type not in POLICY_CLASSES:
        raise ValueError(
            f"Unsupported policy type '{policy_type}'. Supported types: {sorted(POLICY_CLASSES)}"
        )

    # Instantiate the policy model with cleaned config and load the cleaned state dict
    print(f"Instantiating {policy_type} policy model...")
    policy_class_path = POLICY_CLASSES[policy_type]
    module_path, class_name = policy_class_path.rsplit(".", 1)

    module = importlib.import_module(module_path)
    policy_class = getattr(module, class_name)

    # The configuration module sits next to the modeling module
    config_module_path = module_path.replace("modeling", "configuration")
    config_module = importlib.import_module(config_module_path)
    # Handle special cases for config class names
    config_class_names = {
        "act": "ACTConfig",
        "diffusion": "DiffusionConfig",
        "pi0": "PI0Config",
        "pi0fast": "PI0FASTConfig",
        "smolvla": "SmolVLAConfig",
        "tdmpc": "TDMPCConfig",
        "vqbet": "VQBeTConfig",
        "sac": "SACConfig",
        "classifier": "ClassifierConfig",
    }
    config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
    config_class = getattr(config_module, config_class_name)

    # Convert input_features and output_features to PolicyFeature objects - these are mandatory
    if "input_features" not in cleaned_config:
        raise ValueError("Missing mandatory 'input_features' in config")
    if "output_features" not in cleaned_config:
        raise ValueError("Missing mandatory 'output_features' in config")

    cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
    cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])

    # Create config instance from cleaned config dict
    policy_config = config_class(**cleaned_config)

    # Create policy instance
    policy = policy_class(policy_config)

    # Load the cleaned state dict; strict=True makes leftover or missing keys an error
    policy.load_state_dict(new_state_dict, strict=True)
    print("Successfully loaded cleaned state dict into policy model")

    # Now create preprocessor and postprocessor with cleaned_config available
    print("Creating preprocessor and postprocessor...")
    # - Preprocessor has a single NormalizerProcessor covering input and output features
    # - Postprocessor has one UnnormalizerProcessor for output_features only

    # Get features from cleaned_config (now they're PolicyFeature objects)
    input_features = cleaned_config["input_features"]
    output_features = cleaned_config["output_features"]

    preprocessor_steps = [
        RenameProcessor(rename_map={}),
        NormalizerProcessor(
            features={**input_features, **output_features},
            norm_map=norm_map,
            stats=stats,
        ),
        ToBatchProcessor(),
        DeviceProcessor(device=policy_config.device),
    ]
    preprocessor = RobotProcessor(steps=preprocessor_steps, name="robot_preprocessor")

    # Create postprocessor with unnormalizer for outputs only
    postprocessor_steps = [
        DeviceProcessor(device="cpu"),
        UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
    ]
    postprocessor = RobotProcessor(steps=postprocessor_steps, name="robot_postprocessor")

    # Save preprocessor and postprocessor to root directory
    print(f"Saving preprocessor to {output_dir}...")
    preprocessor.save_pretrained(output_dir)
    if args.push_to_hub:
        preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)

    print(f"Saving postprocessor to {output_dir}...")
    postprocessor.save_pretrained(output_dir)
    if args.push_to_hub:
        postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)

    # Save model using the policy's save_pretrained method
    print(f"Saving model to {output_dir}...")
    policy.save_pretrained(
        output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
    )

    # Generate and save model card
    print("Generating model card...")
    # Get metadata from original config. A top-level "repo_id" wins; otherwise
    # look under "dataset".
    # NOTE(review): assumes train_config.json nests the dataset id as
    # train_config["dataset"]["repo_id"] - confirm against a real checkpoint.
    dataset_section = train_config.get("dataset")
    nested_repo_id = dataset_section.get("repo_id") if isinstance(dataset_section, dict) else None
    dataset_repo_id = train_config.get("repo_id") or nested_repo_id or "unknown"
    model_license = config.get("license", "apache-2.0")

    # Always include the base tags; sort so the generated card is deterministic
    tags = sorted(set(config.get("tags") or []) | {"robotics", "lerobot", policy_type})

    # Generate model card
    card = policy.generate_model_card(
        dataset_repo_id=dataset_repo_id, model_type=policy_type, license=model_license, tags=tags
    )

    # Save model card locally
    card.save(str(output_dir / "README.md"))
    print(f"Model card saved to {output_dir / 'README.md'}")
    # Push model card to hub if requested
    if args.push_to_hub:
        from huggingface_hub import HfApi

        api = HfApi()
        api.upload_file(
            path_or_fileobj=str(output_dir / "README.md"),
            path_in_repo="README.md",
            repo_id=hub_repo_id,
            repo_type="model",
            commit_message="Add model card for migrated model",
        )
        print("Model card pushed to hub")

    print("\nMigration complete!")
    print(f"Migrated model saved to: {output_dir}")
    if args.push_to_hub:
        print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")


if __name__ == "__main__":
    main()
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -10,7 +11,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, RobotProcessor, TransitionKey
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
@@ -115,7 +116,7 @@ class NormalizerProcessor:
|
||||
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
|
||||
self.normalize_keys = set(self.normalize_keys)
|
||||
|
||||
def _normalize_obs(self, observation):
|
||||
def _normalize_obs(self, observation, normalized_info):
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
@@ -128,7 +129,20 @@ class NormalizerProcessor:
|
||||
|
||||
processed = dict(observation)
|
||||
for key in keys_to_norm:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
if key not in processed or key not in self.features:
|
||||
continue
|
||||
|
||||
# Check the normalization mode for this feature type
|
||||
feature = self.features[key]
|
||||
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip normalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
normalized_info[key] = "IDENTITY"
|
||||
continue
|
||||
|
||||
# Skip if no stats available for this key
|
||||
if key not in self._tensor_stats:
|
||||
continue
|
||||
|
||||
orig_val = processed[key]
|
||||
@@ -139,16 +153,35 @@ class NormalizerProcessor:
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = (tensor - mean) / (std + self.eps)
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = (tensor - mean) / (std + self.eps)
|
||||
normalized_info[key] = "MEAN_STD"
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
normalized_info[key] = "MIN_MAX"
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
return processed
|
||||
|
||||
def _normalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
def _normalize_action(self, action, normalized_info):
|
||||
if action is None:
|
||||
return action
|
||||
|
||||
# Check the normalization mode for actions
|
||||
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip normalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
normalized_info["action"] = "IDENTITY"
|
||||
return action
|
||||
|
||||
# Skip if no stats available for actions
|
||||
if "action" not in self._tensor_stats:
|
||||
return action
|
||||
|
||||
tensor = (
|
||||
@@ -157,22 +190,42 @@ class NormalizerProcessor:
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return (tensor - mean) / (std + self.eps)
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
normalized_info["action"] = "MEAN_STD"
|
||||
return (tensor - mean) / (std + self.eps)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
normalized_info["action"] = "MIN_MAX"
|
||||
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# If we reach here, the required stats for the normalization mode are not available
|
||||
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._normalize_action(transition.get(TransitionKey.ACTION))
|
||||
# Track what was normalized
|
||||
normalized_info = {}
|
||||
|
||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info)
|
||||
action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info)
|
||||
|
||||
# Create a new transition with normalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Add normalization info to complementary data
|
||||
if normalized_info:
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
comp_data = {} if comp_data is None else dict(comp_data)
|
||||
comp_data["normalized_keys"] = normalized_info
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
@@ -204,7 +257,7 @@ class NormalizerProcessor:
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@@ -253,14 +306,28 @@ class UnnormalizerProcessor:
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
|
||||
def _unnormalize_obs(self, observation):
|
||||
def _unnormalize_obs(self, observation, unnormalized_info):
|
||||
if observation is None:
|
||||
return None
|
||||
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
||||
processed = dict(observation)
|
||||
for key in keys:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
if key not in processed or key not in self.features:
|
||||
continue
|
||||
|
||||
# Check the normalization mode for this feature type
|
||||
feature = self.features[key]
|
||||
norm_mode = self.norm_map.get(feature.type, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip unnormalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
unnormalized_info[key] = "IDENTITY"
|
||||
continue
|
||||
|
||||
# Skip if no stats available for this key
|
||||
if key not in self._tensor_stats:
|
||||
continue
|
||||
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
@@ -268,39 +335,80 @@ class UnnormalizerProcessor:
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = tensor * std + mean
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = tensor * std + mean
|
||||
unnormalized_info[key] = "MEAN_STD"
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
unnormalized_info[key] = "MIN_MAX"
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
return processed
|
||||
|
||||
def _unnormalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
def _unnormalize_action(self, action, unnormalized_info):
|
||||
if action is None:
|
||||
return action
|
||||
|
||||
# Check the normalization mode for actions
|
||||
norm_mode = self.norm_map.get(FeatureType.ACTION, NormalizationMode.IDENTITY)
|
||||
|
||||
# Skip unnormalization if mode is IDENTITY
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
unnormalized_info["action"] = "IDENTITY"
|
||||
return action
|
||||
|
||||
# Skip if no stats available for actions
|
||||
if "action" not in self._tensor_stats:
|
||||
return action
|
||||
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return tensor * std + mean
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
unnormalized_info["action"] = "MEAN_STD"
|
||||
return tensor * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
unnormalized_info["action"] = "MIN_MAX"
|
||||
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
else:
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# If we reach here, the required stats for the normalization mode are not available
|
||||
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
|
||||
# Track what was unnormalized
|
||||
unnormalized_info = {}
|
||||
|
||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info)
|
||||
action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info)
|
||||
|
||||
# Create a new transition with unnormalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Add unnormalization info to complementary data
|
||||
if unnormalized_info:
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
comp_data = {} if comp_data is None else dict(comp_data)
|
||||
comp_data["unnormalized_keys"] = unnormalized_info
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
@@ -327,5 +435,41 @@ class UnnormalizerProcessor:
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
    """Return a deep copy of ``robot_processor`` whose normalization steps use ``stats``.

    The given processor is left untouched. In the copy, every step that is a
    ``NormalizerProcessor`` or ``UnnormalizerProcessor`` gets ``stats`` assigned
    (the same object, not a copy) and its tensor statistics rebuilt from it.

    Args:
        robot_processor: Processor to copy.
        stats: Replacement statistics, ``{feature_key: {stat_name: value}}``.

    Returns:
        The modified copy.
    """
    swapped = deepcopy(robot_processor)
    stat_step_types = (NormalizerProcessor, UnnormalizerProcessor)

    for candidate in swapped.steps:
        if not isinstance(candidate, stat_step_types):
            continue
        candidate.stats = stats
        candidate._tensor_stats = _convert_stats_to_tensors(stats)

    return swapped
|
||||
|
||||
|
||||
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
    """Build a new stats dictionary whose feature keys follow ``rename_map``.

    Keys listed in ``rename_map`` are replaced by their mapped name; all other
    keys are kept. Each per-feature sub-dictionary is deep-copied, so the
    result shares no mutable state with ``stats``, and ``stats`` itself is not
    modified.

    Args:
        stats: Statistics with structure ``{feature_key: {stat_name: value}}``.
        rename_map: Mapping from old feature key to new feature key.

    Returns:
        The renamed copy of ``stats``.

    Example:
        >>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}}
        >>> renamed = rename_stats(stats, {"observation.state": "observation.robot_state"})
        >>> sorted(renamed)
        ['action', 'observation.robot_state']
    """
    return {
        rename_map.get(feature_key, feature_key): deepcopy(feature_stats)
        for feature_key, feature_stats in stats.items()
    }
|
||||
|
||||
@@ -106,9 +106,8 @@ class VanillaObservationProcessor(ObservationProcessor):
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms feature keys to a standardized contract.
|
||||
|
||||
This method handles several renaming patterns:
|
||||
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
|
||||
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
|
||||
|
||||
@@ -23,7 +23,7 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, TypedDict
|
||||
from typing import Any, Protocol, TypedDict, runtime_checkable
|
||||
|
||||
import torch
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
@@ -132,6 +132,7 @@ class ProcessorStepRegistry:
|
||||
cls._registry.clear()
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ProcessorStep(Protocol):
|
||||
"""Structural typing interface for a single processor step.
|
||||
|
||||
@@ -145,7 +146,6 @@ class ProcessorStep(Protocol):
|
||||
|
||||
**Required**:
|
||||
- ``__call__(transition: EnvTransition) -> EnvTransition``
|
||||
- ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
|
||||
|
||||
Optional helper protocol:
|
||||
* ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable
|
||||
@@ -158,6 +158,8 @@ class ProcessorStep(Protocol):
|
||||
* ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict
|
||||
containing torch tensors only.
|
||||
* ``reset()`` – Clear internal buffers at episode boundaries.
|
||||
* ``transform_features(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
|
||||
If present, this method will be called to aggregate the dataset features of all steps.
|
||||
|
||||
Example separation:
|
||||
- get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
|
||||
@@ -174,7 +176,7 @@ class ProcessorStep(Protocol):
|
||||
|
||||
def reset(self) -> None: ...
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
|
||||
|
||||
|
||||
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
||||
@@ -201,10 +203,16 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
observation = observation_keys if observation_keys else None
|
||||
|
||||
# Extract padding and task keys for complementary data
|
||||
# Extract padding, task, index, and task_index keys for complementary data
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
complementary_data = (
|
||||
{**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
if pad_keys or task_key or index_key or task_index_key
|
||||
else {}
|
||||
)
|
||||
|
||||
transition: EnvTransition = {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
@@ -231,7 +239,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Add padding and task data from complementary_data
|
||||
# Add padding, task, index, and task_index data from complementary_data
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data:
|
||||
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
|
||||
@@ -240,6 +248,12 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
if "task" in complementary_data:
|
||||
batch["task"] = complementary_data["task"]
|
||||
|
||||
if "index" in complementary_data:
|
||||
batch["index"] = complementary_data["index"]
|
||||
|
||||
if "task_index" in complementary_data:
|
||||
batch["task_index"] = complementary_data["task_index"]
|
||||
|
||||
# Handle observation - flatten dict to observation.* keys if it's a dict
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if isinstance(observation, dict):
|
||||
@@ -342,7 +356,10 @@ class RobotProcessor(ModelHubMixin):
|
||||
hook(idx, current_transition)
|
||||
|
||||
# Convert back to original format if needed
|
||||
return self.to_output(current_transition) if called_with_batch else current_transition
|
||||
if called_with_batch or self.to_output is not _default_transition_to_batch:
|
||||
return self.to_output(current_transition)
|
||||
else:
|
||||
return current_transition
|
||||
|
||||
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
|
||||
"""Prepare and validate transition data for processing.
|
||||
@@ -575,10 +592,9 @@ class RobotProcessor(ModelHubMixin):
|
||||
if config_filename is None:
|
||||
# Try common config names
|
||||
common_names = [
|
||||
"processor.json",
|
||||
"preprocessor.json",
|
||||
"postprocessor.json",
|
||||
"robotprocessor.json",
|
||||
"robot_processor.json",
|
||||
"robot_preprocessor.json",
|
||||
"robot_postprocessor.json",
|
||||
]
|
||||
config_path = None
|
||||
for name in common_names:
|
||||
@@ -808,23 +824,15 @@ class RobotProcessor(ModelHubMixin):
|
||||
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
|
||||
)
|
||||
|
||||
fc = getattr(step, "feature_contract", None)
|
||||
if not callable(fc):
|
||||
raise TypeError(
|
||||
f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]"
|
||||
)
|
||||
|
||||
def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Apply ALL steps in order. Each step must implement
|
||||
feature_contract(features) and return a dict (full or incremental schema).
|
||||
Apply ALL steps in order, calling each step's ``transform_features`` method on the running feature dict.
|
||||
We aggregate the dataset features of all steps.
|
||||
"""
|
||||
features: dict[str, PolicyFeature] = deepcopy(initial_features)
|
||||
|
||||
for _, step in enumerate(self.steps):
|
||||
out = step.feature_contract(features)
|
||||
if not isinstance(out, dict):
|
||||
raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
|
||||
out = step.transform_features(features)
|
||||
features = out
|
||||
return features
|
||||
|
||||
@@ -884,7 +892,7 @@ class ObservationProcessor:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@@ -944,7 +952,7 @@ class ActionProcessor:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@@ -1003,7 +1011,7 @@ class RewardProcessor:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@@ -1067,7 +1075,7 @@ class DoneProcessor:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@@ -1127,7 +1135,7 @@ class TruncatedProcessor:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@@ -1192,7 +1200,7 @@ class InfoProcessor:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@@ -1238,7 +1246,7 @@ class ComplementaryDataProcessor:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@@ -1260,5 +1268,5 @@ class IdentityProcessor:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -43,7 +43,7 @@ class RenameProcessor(ObservationProcessor):
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"rename_map": self.rename_map}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
||||
- Keys not in `rename_map` remain unchanged.
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Tokenizer processor for handling text tokenization in robot transitions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
else:
|
||||
AutoTokenizer = None
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessor:
    """Tokenizes text tasks in complementary data using a Hugging Face tokenizer.

    The processor reads the task text stored under ``task_key`` in the transition's
    complementary data, tokenizes it, and stores the token ids and the (boolean)
    attention mask in the observation under ``f"{OBS_LANGUAGE}.tokens"`` and
    ``f"{OBS_LANGUAGE}.attention_mask"``. The original task entry is left untouched.

    The task may be a single string or a list/tuple of strings.

    Args:
        tokenizer_name: Name of the pretrained tokenizer to load from the Hugging Face Hub
            (e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
            with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
        tokenizer: A tokenizer object (e.g., from the transformers library) that implements
            the __call__ method. If provided, tokenizer_name is ignored. This parameter
            is not serialized and must be provided via overrides when loading.
        max_length: Maximum sequence length for tokenization. Defaults to 512.
        task_key: Key in complementary_data containing the task text. Defaults to "task".
        padding_side: Side on which padding is applied; forwarded as-is to the tokenizer
            call. Defaults to "right".
        padding: Padding strategy for tokenization. Defaults to "max_length".
        truncation: Whether to truncate sequences longer than max_length. Defaults to True.

    Raises:
        ImportError: If the 'transformers' library is not installed.
        ValueError: If neither ``tokenizer`` nor ``tokenizer_name`` is provided.

    Examples:
        Using tokenizer name (auto-loaded):
        ```python
        processor = TokenizerProcessor(tokenizer_name="bert-base-uncased", max_length=128)
        ```

        Using custom tokenizer object:
        ```python
        from transformers import AutoTokenizer

        custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        processor = TokenizerProcessor(tokenizer=custom_tokenizer, max_length=128)
        ```
    """

    tokenizer_name: str | None = None
    tokenizer: Any | None = None  # Typed as Any because transformers is not a core dependency
    max_length: int = 512
    task_key: str = "task"
    padding_side: str = "right"
    padding: str = "max_length"
    truncation: bool = True

    # Internal tokenizer instance (not serialized)
    _tokenizer: Any = field(default=None, init=False, repr=False)

    def __post_init__(self):
        """Initialize the tokenizer from the provided tokenizer or tokenizer name."""
        if not _transformers_available:
            raise ImportError(
                "The 'transformers' library is not installed. "
                "Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessor."
            )

        if self.tokenizer is not None:
            # Use provided tokenizer object directly
            self._tokenizer = self.tokenizer
        elif self.tokenizer_name is not None:
            if AutoTokenizer is None:
                raise ImportError("AutoTokenizer is not available")
            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        else:
            raise ValueError(
                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

    def get_task(self, transition: EnvTransition) -> list[str] | None:
        """Extract and normalize task from complementary data.

        Args:
            transition: Input transition containing complementary_data.

        Returns:
            List of task strings if a task is present and is a string or a
            list/tuple of strings, None otherwise.
        """
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
        if complementary_data is None:
            return None

        if self.task_key not in complementary_data:
            return None

        task = complementary_data[self.task_key]
        if task is None:
            return None

        # Normalize to a list of strings
        if isinstance(task, str):
            return [task]
        if isinstance(task, (list, tuple)) and all(isinstance(t, str) for t in task):
            return list(task)

        # Unsupported task type: the caller treats this as "nothing to tokenize"
        return None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Process the transition by tokenizing the task text.

        Args:
            transition: Input transition containing complementary_data with task text.

        Returns:
            The input transition unchanged if no usable task is found; otherwise a
            shallow copy of the transition whose observation is a new dict holding
            the tokenized task alongside the existing observation entries.
        """
        task = self.get_task(transition)
        if task is None:
            return transition

        # Tokenize the task, then move the resulting tensors next to the other data
        tokenized_prompt = self._tokenize_text(task)
        target_device = self._detect_device(transition)
        if target_device is not None:
            tokenized_prompt = {
                k: v.to(target_device) if isinstance(v, torch.Tensor) else v
                for k, v in tokenized_prompt.items()
            }

        # Work on a copy of the observation so the caller's dict is not modified
        observation = transition.get(TransitionKey.OBSERVATION)
        observation = {} if observation is None else dict(observation)

        # Add tokenized data to observation
        observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
        observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
            dtype=torch.bool
        )

        # Shallow-copy the transition as well, so the caller's transition keeps
        # pointing at its original observation.
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION.value] = observation  # type: ignore[misc]
        return new_transition

    def _detect_device(self, transition: EnvTransition) -> torch.device | None:
        """Detect device from existing tensors in the transition.

        Search order: observation values, action, reward/done/truncated, then
        complementary data values. The first tensor found decides the device.

        Args:
            transition: The transition to search for existing tensors.

        Returns:
            The device of the first tensor found, or None if no tensors exist.
        """
        # Check observation tensors first
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation:
            for value in observation.values():
                if isinstance(value, torch.Tensor):
                    return value.device

        # Check action tensor
        action = transition.get(TransitionKey.ACTION)
        if isinstance(action, torch.Tensor):
            return action.device

        # Check other tensor fields
        for key in [TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED]:
            value = transition.get(key)
            if isinstance(value, torch.Tensor):
                return value.device

        # Check complementary data for tensors
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
        if complementary_data:
            for value in complementary_data.values():
                if isinstance(value, torch.Tensor):
                    return value.device

        return None  # No tensors found, leave the tokenizer output where it is

    def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
        """Tokenize text using the configured tokenizer.

        Args:
            text: Text string or list of strings to tokenize.

        Returns:
            The tokenizer output requested as PyTorch tensors; ``__call__`` reads
            its 'input_ids' and 'attention_mask' entries.
        """
        return self._tokenizer(
            text,
            max_length=self.max_length,
            truncation=self.truncation,
            padding=self.padding,
            padding_side=self.padding_side,
            return_tensors="pt",
        )

    def get_config(self) -> dict[str, Any]:
        """Return configuration for serialization.

        Note: Only tokenizer_name is saved, not the tokenizer object itself.
        When loading, provide the tokenizer via overrides if needed.
        """
        config = {
            "max_length": self.max_length,
            "task_key": self.task_key,
            "padding_side": self.padding_side,
            "padding": self.padding,
            "truncation": self.truncation,
        }

        # Only include tokenizer_name when one was given
        if self.tokenizer_name is not None:
            config["tokenizer_name"] = self.tokenizer_name

        return config

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return state dictionary (empty for this processor)."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Load state dictionary (no-op for this processor)."""
        pass

    def reset(self) -> None:
        """Reset processor state (no-op for this processor)."""
        pass

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Add tokenized task features to the feature contract.

        Args:
            features: Input feature dictionary. It is not modified.

        Returns:
            A new feature dictionary with the token and attention-mask features
            added when they are not already present.
        """
        # Copy so that calling this method directly does not alter the caller's dict
        updated = dict(features)

        # Both outputs are declared with shape (max_length,); existing entries win
        for key in (f"{OBS_LANGUAGE}.tokens", f"{OBS_LANGUAGE}.attention_mask"):
            if key not in updated:
                updated[key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))

        return updated
|
||||
+149
-36
@@ -59,7 +59,7 @@ lerobot-record \
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
@@ -72,10 +72,19 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy, make_processor
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.converters import (
|
||||
to_dataset_frame,
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
)
|
||||
from lerobot.processor.normalize_processor import rename_stats
|
||||
from lerobot.processor.pipeline import IdentityProcessor, TransitionKey
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -149,6 +158,8 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.single_task is None:
|
||||
@@ -187,6 +198,36 @@ class RecordConfig:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
""" --------------- record_loop() data flow --------------------------
|
||||
[ Robot ]
|
||||
V
|
||||
[ robot.get_observation() ] ---> raw_obs
|
||||
V
|
||||
[ robot_observation_processor ] ---> obs_transition
|
||||
V
|
||||
.-----( ACTION LOGIC )------------------.
|
||||
V V
|
||||
[ From Teleoperator ] [ From Policy ]
|
||||
| |
|
||||
| [teleop.get_action] -> raw_action | [predict_action]
|
||||
| | | |
|
||||
| V | V
|
||||
| [teleop_action_processor] | |
|
||||
| | | |
|
||||
'---> teleop_transition '---> policy_transition
|
||||
| |
|
||||
'-------------------------.-------------'
|
||||
V
|
||||
[ robot_action_processor ] --> robot_action_to_send
|
||||
V
|
||||
[ robot.send_action() ] -- (Robot Executes)
|
||||
V
|
||||
( Transitions are merged & added to Dataset )
|
||||
V
|
||||
( Rerun Log / Loop Wait )
|
||||
"""
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def record_loop(
|
||||
robot: Robot,
|
||||
@@ -195,15 +236,30 @@ def record_loop(
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
preprocessor: RobotProcessor | None = None,
|
||||
postprocessor: RobotProcessor | None = None,
|
||||
control_time_s: int | None = None,
|
||||
teleop_action_processor: RobotProcessor | None = None, # runs after teleop
|
||||
robot_action_processor: RobotProcessor | None = None, # runs before robot
|
||||
robot_observation_processor: RobotProcessor | None = None, # runs after robot
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
):
|
||||
teleop_action_processor = teleop_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
|
||||
)
|
||||
robot_action_processor = robot_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=lambda tr: tr, to_output=to_output_robot_action
|
||||
)
|
||||
robot_observation_processor = robot_observation_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
|
||||
)
|
||||
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
|
||||
teleop_arm = teleop_keyboard = None
|
||||
if isinstance(teleop, list):
|
||||
if isinstance(teleop, list): # For LeKiwi
|
||||
teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
|
||||
teleop_arm = next(
|
||||
(
|
||||
@@ -219,9 +275,20 @@ def record_loop(
|
||||
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
|
||||
)
|
||||
|
||||
# if policy is given it needs cleaning up
|
||||
if policy is not None:
|
||||
# Reset policy and processor if they are provided
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Reset custom pipelines
|
||||
teleop_action_processor.reset()
|
||||
robot_action_processor.reset()
|
||||
robot_observation_processor.reset()
|
||||
|
||||
policy_transition = None
|
||||
teleop_transition = None
|
||||
obs_transition = None
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
@@ -232,51 +299,87 @@ def record_loop(
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
observation = robot.get_observation()
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
|
||||
if policy is not None or dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
# Applies a pipeline to the raw robot observation, default is IdentityProcessor
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
if dataset is not None:
|
||||
observation_frame = to_dataset_frame(
|
||||
obs_transition, dataset.features
|
||||
) # Convert the observation to the dataset format
|
||||
|
||||
if policy is not None:
|
||||
action_values = predict_action(
|
||||
observation_frame,
|
||||
policy,
|
||||
get_safe_torch_device(policy.config.device),
|
||||
policy.config.use_amp,
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
action = teleop.get_action()
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
# TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
|
||||
|
||||
action_names = dataset.features["action"]["names"]
|
||||
policy_action = {f"action.{name}": float(action_values[i]) for i, name in enumerate(action_names)}
|
||||
policy_transition = {
|
||||
TransitionKey.ACTION: policy_action,
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
elif isinstance(teleop, Teleoperator):
|
||||
act = teleop.get_action()
|
||||
|
||||
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
||||
teleop_transition = teleop_action_processor(act)
|
||||
|
||||
elif isinstance(teleop, list):
|
||||
arm_action = teleop_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
|
||||
keyboard_action = teleop_keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
teleop_transition = teleop_action_processor(act)
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
"This is likely to happen when resetting the environment without a teleop device."
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
"No policy or teleoperator provided, skipping action generation. "
|
||||
"This is likely to happen during environment reset."
|
||||
)
|
||||
continue
|
||||
# Still continue to next loop to respect timing
|
||||
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
# IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
|
||||
if policy_transition is not None:
|
||||
robot_action_to_send = robot_action_processor(policy_transition)
|
||||
else:
|
||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
sent_action = robot.send_action(action)
|
||||
# so action actually sent is saved in the dataset.
|
||||
# TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_ = robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||
frame = {**observation_frame, **action_frame}
|
||||
# Merge the available transitions (observation, teleop action, policy action) into one dataset frame.
|
||||
merged = []
|
||||
if obs_transition is not None: # The observation from the robot
|
||||
merged.append(obs_transition)
|
||||
if teleop_transition is not None: # The action from teleop
|
||||
merged.append(teleop_transition)
|
||||
if policy_transition is not None: # The action from policy
|
||||
merged.append(policy_transition)
|
||||
frame = to_dataset_frame(
|
||||
merged if len(merged) > 1 else merged[0], dataset.features
|
||||
) # Convert the observation to the dataset format
|
||||
dataset.add_frame(frame, task=single_task)
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(observation, action)
|
||||
log_rerun_data([obs_transition, teleop_transition or policy_transition])
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
@@ -328,6 +431,18 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
@@ -345,6 +460,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
fps=cfg.dataset.fps,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
@@ -393,9 +510,5 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
record()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
record()
|
||||
|
||||
@@ -14,6 +14,5 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
|
||||
from .config_so100_follower import SO100FollowerConfig
|
||||
from .so100_follower import SO100Follower
|
||||
from .so100_follower_end_effector import SO100FollowerEndEffector
|
||||
|
||||
@@ -39,35 +39,3 @@ class SO100FollowerConfig(RobotConfig):
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so100_follower_end_effector")
@dataclass
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
    """Settings for the SO100FollowerEndEffector robot.

    Adds the kinematics inputs (URDF file and target frame), a box limiting
    the end-effector position, and the scaling/limit values used when
    turning end-effector commands into motor targets.
    """

    # URDF file used to build the kinematics model (unset by default).
    # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
    # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
    urdf_path: str | None = None

    # Name of the URDF frame treated as the end-effector.
    target_frame_name: str = "gripper_frame_link"

    # Lower ("min") and upper ("max") limits of the end-effector position,
    # ordered x, y, z and expressed in meters.
    end_effector_bounds: dict[str, list[float]] = field(
        default_factory=lambda: dict(min=[-1.0, -1.0, -1.0], max=[1.0, 1.0, 1.0])
    )

    # Largest gripper position that may be commanded.
    max_gripper_pos: float = 50

    # Per-axis factor that scales an incoming delta command into a displacement.
    end_effector_step_sizes: dict[str, float] = field(
        default_factory=lambda: dict(x=0.02, y=0.02, z=0.02)
    )
|
||||
|
||||
@@ -0,0 +1,465 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_reference_and_delta")
@dataclass
class EEReferenceAndDelta:
    """
    Turn a relative end-effector command into an absolute end-effector pose.

    The current pose is obtained by forward kinematics on the joint positions
    found in the complementary data. While the command is enabled, the scaled
    translation is added to (and the rotation composed onto) a reference pose;
    while it is disabled, the last commanded pose is re-sent unchanged.

    Input ACTION keys (all are removed from the action; missing ones count as 0):
    {
        "action.enabled": truthy while the delta should be applied,
        "action.target_{x,y,z}": float, multiplied by `end_effector_step_sizes`,
        "action.target_{wx,wy,wz}": float, rotation vector applied to the reference,
    }

    Input COMPLEMENTARY_DATA keys:
    {
        "raw_joint_positions": dict (required),
        "reference_joint_positions": optional joint vector used for FK instead,
    }

    Output ACTION keys:
    {
        "action.ee.{x,y,z,wx,wy,wz}" : float
    }
    """

    kinematics: RobotKinematics
    end_effector_step_sizes: dict
    motor_names: list[str]
    # If True, latch the reference on the rising edge of "enabled";
    # if False, always use the current pose as the reference.
    use_latched_reference: bool = True

    reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
    _prev_enabled: bool = field(default=False, init=False, repr=False)
    _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Replace the relative command in the action with an absolute EE pose.

        Raises:
            ValueError: if "raw_joint_positions" is absent from the complementary data.
        """
        act = transition.get(TransitionKey.ACTION) or {}
        comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}

        # Get joint positions from complementary data
        raw = comp.get("raw_joint_positions", None)
        if raw is None:
            raise ValueError(
                "raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
            )

        if "reference_joint_positions" in comp:
            q = comp["reference_joint_positions"]
        else:
            q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)

        # Current pose from FK on the selected joint vector
        t_curr = self.kinematics.forward_kinematics(q)

        enabled = bool(act.pop("action.enabled", 0))
        tx = float(act.pop("action.target_x", 0.0))
        ty = float(act.pop("action.target_y", 0.0))
        tz = float(act.pop("action.target_z", 0.0))
        wx = float(act.pop("action.target_wx", 0.0))
        wy = float(act.pop("action.target_wy", 0.0))
        wz = float(act.pop("action.target_wz", 0.0))

        if enabled:
            if self.use_latched_reference:
                # Latched reference mode: latch reference at the rising edge
                if not self._prev_enabled or self.reference_ee_pose is None:
                    self.reference_ee_pose = t_curr.copy()
                ref = self.reference_ee_pose
            else:
                ref = t_curr

            delta_p = np.array(
                [
                    tx * self.end_effector_step_sizes["x"],
                    ty * self.end_effector_step_sizes["y"],
                    tz * self.end_effector_step_sizes["z"],
                ],
                dtype=float,
            )
            r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
            desired = np.eye(4, dtype=float)
            desired[:3, :3] = ref[:3, :3] @ r_abs
            desired[:3, 3] = ref[:3, 3] + delta_p

            # Remember the command so it can be held once the input is disabled.
            self._command_when_disabled = desired.copy()
        else:
            # While disabled, keep sending the same command to avoid drift.
            if self._command_when_disabled is None:
                # If we've never had an enabled command yet, freeze current FK pose once.
                self._command_when_disabled = t_curr.copy()
            desired = self._command_when_disabled.copy()

        # Write action fields
        pos = desired[:3, 3]
        tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
        act.update(
            {
                "action.ee.x": float(pos[0]),
                "action.ee.y": float(pos[1]),
                "action.ee.z": float(pos[2]),
                "action.ee.wx": float(tw[0]),
                "action.ee.wy": float(tw[1]),
                "action.ee.wz": float(tw[2]),
            }
        )

        self._prev_enabled = enabled
        transition[TransitionKey.ACTION] = act
        return transition

    def reset(self):
        """Forget the latched reference and the held command.

        Without this, a pose latched during one episode would still be used as
        the reference / hold target at the start of the next one.
        """
        self.reference_ee_pose = None
        self._prev_enabled = False
        self._command_when_disabled = None

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return `features` unchanged; this step declares no dataset features."""
        return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_bounds_and_safety")
@dataclass
class EEBoundsAndSafety(ActionProcessor):
    """
    Clip the end-effector position to the bounds and reject position jumps.

    The position is clipped element-wise to `end_effector_bounds["min"]` /
    `["max"]`. If the clipped position is further than `max_ee_step_m` from the
    previously accepted position, a ValueError is raised. The rotation vector
    is passed through unchanged; `max_ee_twist_step_rad` is not enforced by
    this step.

    Input ACTION keys:
    {
        "action.ee.{x,y,z,wx,wy,wz}" : float
    }

    Output ACTION keys:
    {
        "action.ee.{x,y,z,wx,wy,wz}" : float
    }
    """

    end_effector_bounds: dict
    max_ee_step_m: float = 0.05
    max_ee_twist_step_rad: float = 0.20
    _last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    # Declared so the attribute exists before the first call and is cleared by reset().
    _last_twist: np.ndarray | None = field(default=None, init=False, repr=False)

    def action(self, act: dict | None) -> dict:
        """Return `act` with the end-effector position clipped to the bounds.

        If any of the six "action.ee.*" keys is missing, the keys that were
        present are removed and the rest of the action is returned as is.

        Raises:
            ValueError: if the position moved more than `max_ee_step_m` since
                the last accepted command.
        """
        x = act.pop("action.ee.x", None)
        y = act.pop("action.ee.y", None)
        z = act.pop("action.ee.z", None)
        wx = act.pop("action.ee.wx", None)
        wy = act.pop("action.ee.wy", None)
        wz = act.pop("action.ee.wz", None)

        if None in (x, y, z, wx, wy, wz):
            return act

        pos = np.array([x, y, z], dtype=float)
        twist = np.array([wx, wy, wz], dtype=float)

        # Clip position
        pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"])

        # Check for jumps in position. The step is rejected rather than
        # shortened, so no clamped position is computed here.
        if self._last_pos is not None:
            n = float(np.linalg.norm(pos - self._last_pos))
            if n > self.max_ee_step_m and n > 0:
                raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")

        self._last_pos = pos
        self._last_twist = twist

        act.update(
            {
                "action.ee.x": float(pos[0]),
                "action.ee.y": float(pos[1]),
                "action.ee.z": float(pos[2]),
                "action.ee.wx": float(twist[0]),
                "action.ee.wy": float(twist[1]),
                "action.ee.wz": float(twist[2]),
            }
        )
        return act

    def reset(self):
        """Forget the previously accepted pose so the next command is not jump-checked."""
        self._last_pos = None
        self._last_twist = None

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Declare the six end-effector action entries and return `features`."""
        # Because this is last step we specify the dataset features of this step that we want to be stored in the dataset
        features["action.ee.x"] = float
        features["action.ee.y"] = float
        features["action.ee.z"] = float
        features["action.ee.wx"] = float
        features["action.ee.wy"] = float
        features["action.ee.wz"] = float
        return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
@dataclass
class InverseKinematicsEEToJoints:
    """
    Compute the desired joint positions from the desired end-effector pose.

    Input ACTION keys:
    {
        "action.ee.{x,y,z,wx,wy,wz}" : float
    }

    Input COMPLEMENTARY_DATA keys:
    {
        "raw_joint_positions": dict (required),
    }

    Output ACTION keys (the "action.ee.*" keys are kept as well):
    {
        "action.joint_name_1.pos": float,
        "action.joint_name_2.pos": float,
        ...
        "action.joint_name_n.pos": float,
    }

    For a motor named "gripper", no joint target is written; its measured
    position is stored under "observation.state.gripper.pos" in the action.
    """

    kinematics: RobotKinematics
    motor_names: list[str]
    # Seed handed to the IK solver; also holds the most recent solution.
    q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
    # True: seed IK with the measured joints each call.
    # False: seed IK with the previous IK solution (measured joints on first use).
    initial_guess_current_joints: bool = True

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Solve IK for the pose in the action and add the joint targets to it.

        Raises:
            ValueError: if "raw_joint_positions" is absent from the complementary data.
        """
        act = transition.get(TransitionKey.ACTION) or {}
        comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}

        x = act.get("action.ee.x", None)
        y = act.get("action.ee.y", None)
        z = act.get("action.ee.z", None)
        wx = act.get("action.ee.wx", None)
        wy = act.get("action.ee.wy", None)
        wz = act.get("action.ee.wz", None)

        if None in (x, y, z, wx, wy, wz):
            # Incomplete pose: nothing was removed from the action, so there is
            # nothing to restore. Leave the transition untouched instead of
            # inserting None placeholders for the missing keys.
            return transition

        # Get joint positions from complementary data
        raw = comp.get("raw_joint_positions", None)
        if raw is None:
            raise ValueError(
                "raw_joint_positions is not in complementary data and is required for InverseKinematicsEEToJoints"
            )

        if self.initial_guess_current_joints or self.q_curr is None:
            # Use the measured joints as the initial guess
            self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
        # Otherwise the previous IK solution held in q_curr is the initial guess.

        # Build desired 4x4 transform from pos + rotvec (twist)
        t_des = np.eye(4, dtype=float)
        t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
        t_des[:3, 3] = [x, y, z]

        # Compute inverse kinematics
        q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
        self.q_curr = q_target

        new_act = dict(act)
        for i, name in enumerate(self.motor_names):
            if name == "gripper":
                new_act["observation.state.gripper.pos"] = float(raw["gripper"])
            else:
                new_act[f"action.{name}.pos"] = float(q_target[i])
        transition[TransitionKey.ACTION] = new_act

        if not self.initial_guess_current_joints:
            # Publish the solution so other steps can use it as their reference joints.
            transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
        return transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Declare the end-effector and gripper entries and return `features`."""
        # We specify the dataset features of this step that we want to be stored in the dataset
        features["action.ee.x"] = float
        features["action.ee.y"] = float
        features["action.ee.z"] = float
        features["action.ee.wx"] = float
        features["action.ee.wy"] = float
        features["action.ee.wz"] = float

        features["observation.state.gripper.pos"] = float
        features["action.gripper.pos"] = float
        return features

    def reset(self):
        """Drop the stored IK seed so the next call starts from the measured joints."""
        self.q_curr = None
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
@dataclass
class GripperVelocityToJoint:
    """
    Convert a gripper velocity-style command into a gripper position target.

    The target is the measured gripper position plus the command multiplied by
    `speed_factor`, clipped to [`clip_min`, `clip_max`]. The measured position
    is also written to the observation.

    Input ACTION keys:
    {
        "action.gripper": float,
    }

    Input COMPLEMENTARY_DATA keys:
    {
        "raw_joint_positions": dict containing "gripper",
    }

    Output ACTION keys:
    {
        "action.gripper.pos": float,
    }

    Output OBSERVATION keys:
    {
        "observation.state.gripper.pos": float,
    }
    """

    motor_names: list[str]
    speed_factor: float = 20.0
    clip_min: float = 0.0
    clip_max: float = 100.0
    discrete_gripper: bool = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Replace "action.gripper" with "action.gripper.pos".

        The transition is returned untouched when the action has no
        "action.gripper" key. When "gripper" is not one of `motor_names`, the
        command is simply removed from the action.

        Raises:
            ValueError: if the measured gripper position is not available in
                complementary_data["raw_joint_positions"].
        """
        obs = transition.get(TransitionKey.OBSERVATION) or {}
        act = transition.get(TransitionKey.ACTION) or {}
        comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}

        if "action.gripper" not in act:
            return transition

        # Work on a copy so the incoming action dict is never modified in place.
        new_act = dict(act)
        command = new_act.pop("action.gripper")

        if "gripper" not in self.motor_names:
            transition[TransitionKey.ACTION] = new_act
            return transition

        if self.discrete_gripper:
            # Discrete gripper actions are in [0, 1, 2]
            # 0: open, 1: close, 2: stay
            # We need to shift them to [-1, 0, 1] and then scale them to clip_max
            command = (command - 1.0) * self.clip_max

        # Get current gripper position from complementary data
        raw = comp.get("raw_joint_positions") or {}
        if raw.get("gripper") is None:
            raise ValueError(
                "raw_joint_positions['gripper'] is not in complementary data and is required for GripperVelocityToJoint"
            )
        curr_pos = float(raw["gripper"])

        # Integrate the command into a position target and keep it in range
        delta = float(command) * float(self.speed_factor)
        gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))

        new_act["action.gripper.pos"] = gripper_pos
        transition[TransitionKey.ACTION] = new_act

        obs.update({"observation.state.gripper.pos": curr_pos})
        transition[TransitionKey.OBSERVATION] = obs
        return transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Declare the gripper observation/action entries and return `features`."""
        # We specify the dataset features of this step that we want to be stored in the dataset
        features["observation.state.gripper.pos"] = float
        features["action.gripper.pos"] = float
        return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
@dataclass
class ForwardKinematicsJointsToEE(ObservationProcessor):
    """
    Derive the end-effector pose from the joint positions in the observation.

    Input OBSERVATION keys:
    {
        "observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
    }

    Output OBSERVATION keys:
    {
        "observation.state.ee.{x,y,z,wx,wy,wz}" : float
    }
    """

    kinematics: RobotKinematics
    motor_names: list[str]

    def observation(self, obs: dict | None) -> dict:
        """Add the FK pose (position + rotation vector) to `obs` and return it.

        `obs` is returned unchanged when any joint position key is missing.
        """
        joint_keys = [f"observation.state.{n}.pos" for n in self.motor_names]
        if any(key not in obs for key in joint_keys):
            return obs

        joint_vector = np.array([obs[key] for key in joint_keys], dtype=float)
        pose = self.kinematics.forward_kinematics(joint_vector)
        translation = pose[:3, 3]
        rotvec = Rotation.from_matrix(pose[:3, :3]).as_rotvec()

        obs["observation.state.ee.x"] = float(translation[0])
        obs["observation.state.ee.y"] = float(translation[1])
        obs["observation.state.ee.z"] = float(translation[2])
        obs["observation.state.ee.wx"] = float(rotvec[0])
        obs["observation.state.ee.wy"] = float(rotvec[1])
        obs["observation.state.ee.wz"] = float(rotvec[2])
        return obs

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Declare the six end-effector observation entries and return `features`."""
        # These are the dataset features this step wants stored in the dataset
        features.update(
            {f"observation.state.ee.{axis}": float for axis in ("x", "y", "z", "wx", "wy", "wz")}
        )
        return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_robot_observation")
@dataclass
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
    """
    Query the robot for its current observation and expose the joint readings
    to later steps through the complementary data.

    - Every observation entry whose key is a string ending in ".pos" is stored
      under complementary_data["raw_joint_positions"] as:
        { "<motor_name>": <float position>, ... }
    """

    robot: Robot

    def complementary_data(self, comp: dict | None) -> dict:
        """Return a copy of `comp` with "raw_joint_positions" filled in."""
        result = dict(comp) if comp is not None else {}

        joint_positions = {}
        for key, value in self.robot.get_observation().items():
            # Only "<motor>.pos" entries are joint readings; anything else is ignored.
            if not isinstance(key, str) or not key.endswith(".pos"):
                continue
            joint_positions[key.removesuffix(".pos")] = float(value)

        result["raw_joint_positions"] = joint_positions
        return result

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return `features` unchanged; this step declares no dataset features."""
        return features
|
||||
@@ -1,200 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras import make_cameras_from_configs
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.motors import Motor, MotorNormMode
|
||||
from lerobot.motors.feetech import FeetechMotorsBus
|
||||
|
||||
from . import SO100Follower
|
||||
from .config_so100_follower import SO100FollowerEndEffectorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SO100FollowerEndEffector(SO100Follower):
    """
    SO100Follower variant that is commanded in end-effector space.

    Incoming actions are end-effector deltas; they are converted to joint
    targets with inverse kinematics and then forwarded to the parent class.
    """

    config_class = SO100FollowerEndEffectorConfig
    name = "so100_follower_end_effector"

    def __init__(self, config: SO100FollowerEndEffectorConfig):
        """Build the motor bus, cameras and kinematics model from `config`.

        Raises:
            ValueError: if `config.urdf_path` is None.
        """
        super().__init__(config)

        # Arm joints are normalized in degrees; the gripper uses the 0-100 range.
        arm_joints = ("shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll")
        motors = {
            joint: Motor(motor_id, "sts3215", MotorNormMode.DEGREES)
            for motor_id, joint in enumerate(arm_joints, start=1)
        }
        motors["gripper"] = Motor(6, "sts3215", MotorNormMode.RANGE_0_100)
        self.bus = FeetechMotorsBus(
            port=self.config.port,
            motors=motors,
            calibration=self.calibration,
        )

        self.cameras = make_cameras_from_configs(config.cameras)

        self.config = config

        # The kinematics module cannot be built without a URDF
        if self.config.urdf_path is None:
            raise ValueError(
                "urdf_path must be provided in the configuration for end-effector control. "
                "Please set urdf_path in your SO100FollowerEndEffectorConfig."
            )

        self.kinematics = RobotKinematics(
            urdf_path=self.config.urdf_path,
            target_frame_name=self.config.target_frame_name,
        )

        # Limits applied to the commanded end-effector position
        self.end_effector_bounds = self.config.end_effector_bounds

        # Cached commanded state; filled on the first send_action and cleared by reset()
        self.current_ee_pos = None
        self.current_joint_pos = None

    @property
    def action_features(self) -> dict[str, Any]:
        """
        Describe the end-effector action layout.
        Returns dictionary with dtype, shape, and names.
        """
        names = {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}
        return {"dtype": "float32", "shape": (4,), "names": names}

    def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
        """
        Convert an end-effector delta action to joint targets and send them to the motors.

        Args:
            action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' (and optionally
                'gripper'; a default is inserted into the dictionary when absent),
                or a numpy array whose first three entries are the deltas and whose
                last entry is the gripper command

        Returns:
            The joint-space action that was sent to the motors

        Raises:
            DeviceNotConnectedError: if the robot is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        # Normalize dictionary actions into a flat [dx, dy, dz, gripper] array
        if isinstance(action, dict):
            required = ("delta_x", "delta_y", "delta_z")
            if any(key not in action for key in required):
                logger.warning(
                    f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
                )
                action = np.zeros(4, dtype=np.float32)
            else:
                step_sizes = self.config.end_effector_step_sizes
                delta_ee = np.array(
                    [action[f"delta_{axis}"] * step_sizes[axis] for axis in ("x", "y", "z")],
                    dtype=np.float32,
                )
                action = np.append(delta_ee, action.setdefault("gripper", [1.0]))

        if self.current_joint_pos is None:
            # First command since construction/reset: seed from the measured joints
            measured = self.bus.sync_read("Present_Position")
            self.current_joint_pos = np.array([measured[name] for name in self.bus.motors])

        if self.current_ee_pos is None:
            # Seed the end-effector pose with forward kinematics
            self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)

        # Desired pose keeps the current orientation and shifts the position by the delta
        desired_ee_pos = np.eye(4)
        desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3]
        desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
        bounds = self.end_effector_bounds
        if bounds is not None:
            desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])

        # Inverse kinematics gives the joint targets for the desired pose
        target_joint_values_in_degrees = self.kinematics.inverse_kinematics(
            self.current_joint_pos, desired_ee_pos
        )

        # Assemble the joint-space action
        joint_action = {}
        for i, key in enumerate(self.bus.motors.keys()):
            joint_action[f"{key}.pos"] = target_joint_values_in_degrees[i]

        # The gripper is handled separately.
        # Its delta command is in the range 0 - 2; shifting it to -1, 1 lets it be
        # expanded to -max_gripper_pos, max_gripper_pos before being added to the
        # last commanded gripper position.
        max_gripper = self.config.max_gripper_pos
        gripper_target = self.current_joint_pos[-1] + (action[-1] - 1) * max_gripper
        joint_action["gripper.pos"] = np.clip(gripper_target, 5, max_gripper)

        # Remember what was commanded; the next delta is applied relative to it
        self.current_ee_pos = desired_ee_pos.copy()
        self.current_joint_pos = target_joint_values_in_degrees.copy()
        self.current_joint_pos[-1] = joint_action["gripper.pos"]

        # Send joint space action to parent class
        return super().send_action(joint_action)

    def get_observation(self) -> dict[str, Any]:
        """Read joint positions and camera frames into a single dictionary.

        Raises:
            DeviceNotConnectedError: if the robot is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        # Read arm position
        start = time.perf_counter()
        positions = self.bus.sync_read("Present_Position")
        obs_dict = {f"{motor}.pos": val for motor, val in positions.items()}
        dt_ms = (time.perf_counter() - start) * 1e3
        logger.debug(f"{self} read state: {dt_ms:.1f}ms")

        # Capture images from cameras
        for cam_key, cam in self.cameras.items():
            start = time.perf_counter()
            obs_dict[cam_key] = cam.async_read()
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

        return obs_dict

    def reset(self):
        """Clear the cached pose and joints so the next action re-reads the robot."""
        self.current_ee_pos = None
        self.current_joint_pos = None
|
||||
@@ -69,6 +69,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
raise ValueError(config.type)
|
||||
|
||||
|
||||
# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset
|
||||
def ensure_safe_goal_position(
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
|
||||
) -> dict[str, float]:
|
||||
|
||||
@@ -62,9 +62,16 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.scripts.rl.gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
@@ -236,7 +243,8 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
||||
env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
@@ -257,6 +265,12 @@ def act_with_policy(
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
@@ -274,45 +288,61 @@ def act_with_policy(
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
observation = transition[TransitionKey.OBSERVATION]
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
# Use the new step function
|
||||
new_transition = step_env_and_process_transition(
|
||||
env=online_env,
|
||||
transition=transition,
|
||||
action=action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
|
||||
# Extract values from processed transition
|
||||
next_observation = new_transition[TransitionKey.OBSERVATION]
|
||||
executed_action = new_transition[TransitionKey.ACTION]
|
||||
reward = new_transition[TransitionKey.REWARD]
|
||||
done = new_transition.get(TransitionKey.DONE, False)
|
||||
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
complementary_info = {
|
||||
"discrete_penalty": torch.tensor(
|
||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||
),
|
||||
}
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
state=observation,
|
||||
action=executed_action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
next_state=next_observation,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
# Update transition for next iteration
|
||||
transition = new_transition
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
@@ -347,12 +377,20 @@ def act_with_policy(
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
# Reset intervention counters and environment
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -75,6 +75,7 @@ from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl import learner_service
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
MAX_MESSAGE_SIZE,
|
||||
@@ -1048,10 +1049,8 @@ def get_observation_features(
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(
|
||||
next_observations, normalize=True
|
||||
)
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
@@ -1176,7 +1175,7 @@ def process_transitions(
|
||||
|
||||
# Add to offline buffer if it's an intervention
|
||||
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
"is_intervention"
|
||||
TeleopEvents.IS_INTERVENTION
|
||||
):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
|
||||
@@ -26,12 +26,13 @@ from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy, make_processor
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
@@ -140,6 +141,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
||||
)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
@@ -149,6 +153,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json")
|
||||
postprocessor.from_pretrained(
|
||||
cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
@@ -203,12 +211,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
batch = preprocessor(batch)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
@@ -240,7 +245,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
save_checkpoint(
|
||||
checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
@@ -284,6 +291,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
policy.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
from typing import Any, Callable
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
has_method,
|
||||
init_logging,
|
||||
is_launched_with_accelerate,
|
||||
)
|
||||
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
    accelerator: Callable = None,
) -> tuple[MetricsTracker, dict]:
    """Run a single optimisation step on ``policy`` and record its metrics.

    The step is: forward pass, backward pass, gradient clipping, optimizer
    step, gradient reset, optional scheduler step and, when the (unwrapped)
    policy exposes an ``update`` method, a call to that hook.

    Args:
        train_metrics: Tracker that receives ``loss``, ``grad_norm``, ``lr``
            and ``update_s`` for this step.
        policy: Policy to train; ``policy.forward(batch)`` must return
            ``(loss, output_dict)``.
        batch: Batch passed as-is to ``policy.forward``.
        optimizer: Optimizer stepped once and then zeroed.
        grad_clip_norm: Max norm handed to ``clip_grad_norm_``.
        grad_scaler: Accepted for interface compatibility; not used here.
        lr_scheduler: Optional scheduler, stepped once per call (per batch).
        use_amp: Accepted for interface compatibility; not used here.
        lock: Accepted for interface compatibility; not used here.
        accelerator: ``accelerate.Accelerator`` driving the backward pass.
            When ``None`` (the declared default) a plain PyTorch backward pass
            is used instead of failing on ``None.backward``.

    Returns:
        The updated ``train_metrics`` and the ``output_dict`` produced by the
        policy's forward pass.
    """
    start_time = time.perf_counter()

    policy.train()

    loss, output_dict = policy.forward(batch)

    if accelerator is not None:
        accelerator.backward(loss)
        # Gradients must be unscaled before their norm is measured/clipped.
        accelerator.unscale_gradients(optimizer=optimizer)
    else:
        loss.backward()

    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )
    optimizer.step()

    optimizer.zero_grad()

    # Step through pytorch scheduler at every batch instead of epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

    # The optional `update` hook lives on the underlying policy, not on the
    # wrapper that accelerate may have put around it.
    if accelerator is not None:
        unwrapped_policy = accelerator.unwrap_model(policy, keep_fp32_wrapper=True)
    else:
        unwrapped_policy = policy
    if has_method(unwrapped_policy, "update"):
        unwrapped_policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict
|
||||
|
||||
|
||||
@parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: Callable):
    """Train a policy offline on a fixed dataset using an ``accelerate`` Accelerator.

    Builds the dataset, optional evaluation environment, policy, optimizer and
    dataloader, hands them to ``accelerator.prepare`` and then runs
    ``cfg.steps`` policy updates. Logging, checkpointing and evaluation are
    performed on the main process only.

    Args:
        cfg: Training pipeline configuration (validated on entry).
        accelerator: The ``accelerate.Accelerator`` coordinating the processes.
    """
    cfg.validate()
    logging.info(pformat(cfg.to_dict()))

    if not accelerator.is_main_process:
        # Disable logging on non-main processes.
        cfg.wandb.enable = False

    if cfg.wandb.enable and cfg.wandb.project:
        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    if cfg.seed is not None:
        set_seed(cfg.seed, accelerator=accelerator)

    # Check device is available
    device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("Creating dataset")
    dataset = make_dataset(cfg)

    # Create environment used for evaluating checkpoints during training on simulation data.
    # On real-world data, no need to create an environment as evaluations are done outside train.py,
    # using the eval.py instead, with gym_dora environment and dora-rs.
    eval_env = None
    if cfg.eval_freq > 0 and cfg.env is not None:
        logging.info("Creating env")
        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)

    logging.info("Creating policy")
    policy = make_policy(
        cfg=cfg.policy,
        device=device,
        ds_meta=dataset.meta,
    )
    policy.to(device)
    logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
    grad_scaler = GradScaler(device, enabled=cfg.use_amp)

    step = 0  # number of policy updates (forward + backward + optim)

    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())
    if accelerator.is_main_process:
        logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
        if cfg.env is not None:
            logging.info(f"{cfg.env.task=}")
        logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
        logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
        logging.info(f"{dataset.num_episodes=}")
        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    # create dataloader for offline training
    if hasattr(cfg.policy, "drop_n_last_frames"):
        # Shuffling is delegated to the sampler, so the DataLoader itself must not shuffle.
        shuffle = False
        sampler = EpisodeAwareSampler(
            dataset.episode_data_index,
            drop_n_last_frames=cfg.policy.drop_n_last_frames,
            shuffle=True,
        )
    else:
        shuffle = True
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        drop_last=False,
    )

    policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        policy, optimizer, dataloader, lr_scheduler
    )

    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }

    train_tracker = MetricsTracker(
        cfg.batch_size,
        dataset.num_frames,
        dataset.num_episodes,
        train_metrics,
        initial_step=step,
        accelerator=accelerator,
    )
    if accelerator.is_main_process:
        logging.info("Start offline training on a fixed dataset")

    for _ in range(step, cfg.steps):
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        train_tracker, output_dict = update_policy(
            train_tracker,
            policy,
            batch,
            optimizer,
            cfg.optimizer.grad_clip_norm,
            grad_scaler=grad_scaler,
            lr_scheduler=lr_scheduler,
            use_amp=cfg.use_amp,
            accelerator=accelerator,
        )

        # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
        # increment `step` here.
        step += 1
        train_tracker.step()
        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and accelerator.is_main_process
        # Parenthesised on purpose: `and` binds tighter than `or`, and the main-process
        # restriction must apply to both the periodic and the final checkpoint.
        is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and accelerator.is_main_process
        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 and accelerator.is_main_process

        if is_log_step:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
                    wandb_log_dict.update(output_dict)
                wandb_logger.log_dict(wandb_log_dict, step)
            train_tracker.reset_averages()

        if cfg.save_checkpoint and is_saving_step:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
            save_checkpoint(
                checkpoint_dir,
                step,
                cfg,
                accelerator.unwrap_model(policy),
                optimizer,
                lr_scheduler,
            )
            update_last_checkpoint(checkpoint_dir)
            if wandb_logger:
                wandb_logger.log_policy(checkpoint_dir)

        # Must be reached by every process (not only the one that saved),
        # otherwise the barrier would never be released.
        accelerator.wait_for_everyone()

        if cfg.env and is_eval_step:
            step_id = get_step_identifier(step, cfg.steps)
            logging.info(f"Eval policy at step {step}")

            with torch.no_grad():
                eval_info = eval_policy(
                    env=eval_env,
                    policy=accelerator.unwrap_model(policy),
                    n_episodes=cfg.eval.n_episodes,
                    videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                    max_episodes_rendered=4,
                    start_seed=cfg.seed,
                )

            eval_metrics = {
                "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
                "pc_success": AverageMeter("success", ":.1f"),
                "eval_s": AverageMeter("eval_s", ":.3f"),
            }
            eval_tracker = MetricsTracker(
                cfg.batch_size,
                dataset.num_frames,
                dataset.num_episodes,
                eval_metrics,
                initial_step=step,
                accelerator=None,
            )
            eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
            eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
            eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
            logging.info(eval_tracker)
            if wandb_logger:
                wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
                wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")

    if eval_env:
        eval_env.close()
    if not accelerator or accelerator.is_main_process:
        logging.info("End of training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    init_logging()

    # `step_scheduler_with_optimizer=False` keeps accelerate from rescaling the
    # number of lr_scheduler steps according to the number of processes.
    accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
    train(accelerator=accelerator)
|
||||
@@ -109,7 +109,7 @@ def teleop_loop(
|
||||
action = teleop.get_action()
|
||||
if display_data:
|
||||
observation = robot.get_observation()
|
||||
log_rerun_data(observation, action)
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
robot.send_action(action)
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
|
||||
@@ -16,4 +16,4 @@
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
from .utils import make_teleoperator_from_config
|
||||
from .utils import TeleopEvents, make_teleoperator_from_config
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
import logging
|
||||
|
||||
from ..utils import TeleopEvents
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
@@ -134,10 +136,10 @@ class KeyboardController(InputController):
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -255,13 +257,13 @@ class GamepadController(InputController):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
@@ -451,11 +453,11 @@ class GamepadControllerHID(InputController):
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
from .configuration_gamepad import GamepadTeleopConfig
|
||||
|
||||
|
||||
@@ -93,9 +94,9 @@ class GamepadTeleop(Teleoperator):
|
||||
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
|
||||
|
||||
action_dict = {
|
||||
"delta_x": gamepad_action[0],
|
||||
"delta_y": gamepad_action[1],
|
||||
"delta_z": gamepad_action[2],
|
||||
"action.delta_x": gamepad_action[0],
|
||||
"action.delta_y": gamepad_action[1],
|
||||
"action.delta_z": gamepad_action[2],
|
||||
}
|
||||
|
||||
# Default gripper action is to stay
|
||||
@@ -107,6 +108,48 @@ class GamepadTeleop(Teleoperator):
|
||||
|
||||
return action_dict
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
    """
    Report the gamepad's control events for the current step.

    The gamepad state is refreshed first, then the intervention flag and the
    episode end status are read from it.

    Returns:
        Dictionary keyed by `TeleopEvents` members:
            - IS_INTERVENTION: result of the gamepad's `should_intervene()`
            - TERMINATE_EPISODE: True when the end status is FAILURE or RERECORD_EPISODE
            - SUCCESS: True when the end status is SUCCESS
            - RERECORD_EPISODE: True when the end status is RERECORD_EPISODE
        Every value is False when no gamepad is attached.
    """
    event_keys = (
        TeleopEvents.IS_INTERVENTION,
        TeleopEvents.TERMINATE_EPISODE,
        TeleopEvents.SUCCESS,
        TeleopEvents.RERECORD_EPISODE,
    )

    controller = self.gamepad
    if controller is None:
        # Nothing to poll: report an all-clear event set.
        return dict.fromkeys(event_keys, False)

    # Refresh the controller before reading anything from it.
    controller.update()
    intervening = controller.should_intervene()
    end_status = controller.get_episode_end_status()

    should_terminate = end_status in (TeleopEvents.RERECORD_EPISODE, TeleopEvents.FAILURE)
    succeeded = end_status == TeleopEvents.SUCCESS
    wants_rerecord = end_status == TeleopEvents.RERECORD_EPISODE

    return dict(zip(event_keys, (intervening, should_terminate, succeeded, wants_rerecord)))
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the gamepad."""
|
||||
if self.gamepad is not None:
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig
|
||||
|
||||
PYNPUT_AVAILABLE = True
|
||||
@@ -167,13 +168,13 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2, "action.gripper": 3},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (3,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
|
||||
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2},
|
||||
}
|
||||
|
||||
def _on_press(self, key):
|
||||
@@ -226,12 +227,75 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
self.current_pressed.clear()
|
||||
|
||||
action_dict = {
|
||||
"delta_x": delta_x,
|
||||
"delta_y": delta_y,
|
||||
"delta_z": delta_z,
|
||||
"action.delta_x": delta_x,
|
||||
"action.delta_y": delta_y,
|
||||
"action.delta_z": delta_z,
|
||||
}
|
||||
|
||||
if self.config.use_gripper:
|
||||
action_dict["gripper"] = gripper_action
|
||||
|
||||
return action_dict
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
    """
    Report the keyboard's control events for the current step.

    Keyboard mappings, as implemented below:
        - Any arrow, shift or ctrl key held = intervention active
        - 's' key = sets the success flag (does not set terminate_episode)
        - 'r' key = terminate the episode and request a rerecord
        - 'q' key = terminate the episode and clear the success flag

    All keys waiting in `misc_keys_queue` are consumed by this call; they are
    applied in queue order, so a later 'q' overrides an earlier 's'.

    Returns:
        Dictionary keyed by `TeleopEvents` members:
            - IS_INTERVENTION: bool - Whether a movement key is currently pressed
            - TERMINATE_EPISODE: bool - Whether to terminate the current episode
            - SUCCESS: bool - Whether the episode was flagged as successful
            - RERECORD_EPISODE: bool - Whether to rerecord the episode
        Every value is False when the keyboard is not connected.
    """
    if not self.is_connected:
        return {
            TeleopEvents.IS_INTERVENTION: False,
            TeleopEvents.TERMINATE_EPISODE: False,
            TeleopEvents.SUCCESS: False,
            TeleopEvents.RERECORD_EPISODE: False,
        }

    # A held movement key is what signals a human intervention.
    movement_keys = (
        keyboard.Key.up,
        keyboard.Key.down,
        keyboard.Key.left,
        keyboard.Key.right,
        keyboard.Key.shift,
        keyboard.Key.shift_r,
        keyboard.Key.ctrl_r,
        keyboard.Key.ctrl_l,
    )
    intervening = False
    for movement_key in movement_keys:
        if self.current_pressed.get(movement_key, False):
            intervening = True
            break

    # Drain the queue of episode-control keys.
    should_terminate = False
    succeeded = False
    wants_rerecord = False
    pending = self.misc_keys_queue
    while not pending.empty():
        pressed = pending.get_nowait()
        if pressed == "s":
            succeeded = True
        elif pressed == "r":
            should_terminate = True
            wants_rerecord = True
        elif pressed == "q":
            should_terminate = True
            succeeded = False

    return {
        TeleopEvents.IS_INTERVENTION: intervening,
        TeleopEvents.TERMINATE_EPISODE: should_terminate,
        TeleopEvents.SUCCESS: succeeded,
        TeleopEvents.RERECORD_EPISODE: wants_rerecord,
    }
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_phone import PhoneConfig
|
||||
from .phone import Phone
|
||||
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
class PhoneOS(Enum):
    """Phone operating systems selectable in `PhoneConfig.phone_os`."""

    ANDROID = "android"
    IOS = "ios"
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("phone")
@dataclass
class PhoneConfig(TeleoperatorConfig):
    """Configuration for the phone teleoperator, registered under the name "phone"."""

    # Operating system of the phone; defaults to iOS.
    phone_os: PhoneOS = PhoneOS.IOS

    # iPhone 14 Pro camera is 2cm off center and 4cm above center.
    # Left unannotated, so this is a plain class attribute rather than a dataclass field.
    camera_offset = np.array([0.0, -0.02, 0.04])
|
||||
@@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Docs:
|
||||
# hebi: https://docs.hebi.us/tools.html#mobile-io
|
||||
# teleop: https://github.com/SpesRobotics/teleop
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import hebi
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
from teleop import Teleop
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Phone(Teleoperator):
|
||||
"""
|
||||
Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API).
|
||||
For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.
|
||||
|
||||
Press and hold **B1** to enable teleoperation. While enabled, the first B1 press
|
||||
captures a reference pose and rotation, when disabled and pressed again the position is reapplied.
|
||||
"""
|
||||
|
||||
config_class = PhoneConfig
|
||||
name = "phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._group = None
|
||||
self._teleop = None
|
||||
self._teleop_thread = None
|
||||
self._latest_pose = None
|
||||
self._latest_message = None
|
||||
self._enabled: bool = False
|
||||
self._calib_pos: np.ndarray | None = None
|
||||
self._calib_rot_inv: Rotation | None = None
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return (self.config.phone_os == PhoneOS.IOS and self._group is not None) or (
|
||||
self.config.phone_os == PhoneOS.ANDROID and self._teleop is not None
|
||||
)
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
|
||||
lookup = hebi.Lookup()
|
||||
time.sleep(2.0)
|
||||
group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
|
||||
if group is None:
|
||||
raise RuntimeError("Mobile I/O not found — check name/family settings in the app.")
|
||||
self._group = group
|
||||
logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
|
||||
elif self.config.phone_os == PhoneOS.ANDROID:
|
||||
logger.info("Starting teleop stream for Android...")
|
||||
self._teleop = Teleop()
|
||||
self._teleop.subscribe(self._android_callback)
|
||||
self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
|
||||
self._teleop_thread.start()
|
||||
logger.info(f"{self} connected, teleop stream started.")
|
||||
else:
|
||||
raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")
|
||||
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
print(
|
||||
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
|
||||
)
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
|
||||
else:
|
||||
print("Touch and move on the WebXR page to capture this pose...\n")
|
||||
|
||||
pos, rot = self._wait_for_capture_trigger()
|
||||
self._calib_pos = pos.copy()
|
||||
self._calib_rot_inv = rot.inv()
|
||||
self._enabled = False
|
||||
print("Calibration done\n")
|
||||
|
||||
def _reapply_position_calibration(self, pos: np.ndarray) -> None:
|
||||
self._calib_pos = pos.copy()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {
|
||||
"phone.pos": np.ndarray, # shape (3,)
|
||||
"phone.rot": Rotation, # scipy.spatial.transform.Rotation
|
||||
"phone.raw_inputs": dict, # analogs/buttons or webXR meta
|
||||
"phone.enabled": bool,
|
||||
}
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
    """Block until the capture trigger fires and return the pose at that moment.

    The trigger is a truthy integer on channel 1 of the ``b`` I/O bank on iOS,
    and a truthy ``"move"`` entry in the latest message otherwise. The loop
    polls every 10 ms and has no timeout.

    Returns:
        The ``(position, rotation)`` pair read in the iteration that saw the trigger.
    """
    while True:
        ok, pos, rot, pose = self._read_current_pose()

        if ok:
            if self.config.phone_os == PhoneOS.IOS:
                io = getattr(pose, "io", None)
                button_bank = getattr(io, "b", None) if io is not None else None
                triggered = button_bank is not None and bool(button_bank.get_int(1))
            else:
                latest = self._latest_message or {}
                triggered = bool(latest.get("move", False))

            if triggered:
                return pos, rot

        # Either no pose yet or the trigger is not active: wait and poll again.
        time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
    """Read the most recent phone pose for the configured platform.

    Returns:
        A tuple ``(ok, pos, rot, pose)``. ``ok`` is ``False`` (and the other three
        entries are ``None``) when no pose is available. Otherwise ``rot`` is the
        orientation as a ``Rotation``, ``pos`` is the reported position minus the
        rotated ``config.camera_offset``, and ``pose`` is the raw platform object
        the values were derived from (the first feedback entry on iOS, the pose
        matrix otherwise).
    """
    if self.config.phone_os == PhoneOS.IOS:
        fbk = self._group.get_next_feedback()
        # No feedback packet available (e.g. the read timed out): report
        # "no pose" instead of failing on the index below.
        if fbk is None:
            return False, None, None, None
        pose = fbk[0]
        ar_pos = getattr(pose, "ar_position", None)
        ar_quat = getattr(pose, "ar_orientation", None)
        if ar_pos is None or ar_quat is None:
            return False, None, None, None
        quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]]))  # wxyz to xyzw
        rot = Rotation.from_quat(quat_xyzw)
        pos = ar_pos - rot.apply(self.config.camera_offset)
        return True, pos, rot, pose
    else:
        # Take a single snapshot: the attribute is overwritten by the Android
        # callback, so re-reading it could return a different pose than the one
        # pos/rot were computed from.
        p = self._latest_pose
        if p is None:
            return False, None, None, None
        rot = Rotation.from_matrix(p[:3, :3])
        pos = p[:3, 3] - rot.apply(self.config.camera_offset)
        return True, pos, rot, p
|
||||
|
||||
@property
def feedback_features(self) -> dict[str, type]:
    """Feedback channels supported by the phone teleoperator.

    Returns:
        An empty dict: no haptic or other feedback is implemented yet. Returning
        a dict (rather than ``None``) matches the declared return type.
    """
    return {}
|
||||
|
||||
def configure(self) -> None:
    """Do nothing: phone teleop requires no additional configuration."""
    return None
|
||||
|
||||
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
|
||||
self._latest_pose = pose
|
||||
self._latest_message = message
|
||||
time.sleep(0.001) # 1ms delay to avoid race condition
|
||||
|
||||
def get_action(self) -> dict:
    """Return the calibrated phone pose together with the raw inputs.

    Returns:
        An empty dict when no pose could be read or calibration is missing.
        Otherwise a dict with the keys ``"phone.pos"``, ``"phone.rot"``,
        ``"phone.raw_inputs"`` and ``"phone.enabled"``.
    """
    ok, raw_pos, raw_rot, pose = self._read_current_pose()
    if not (ok and self.is_calibrated):
        return {}

    # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) and derive
    # the enable signal from them.
    raw_inputs: dict[str, float | int | bool] = {}
    if self.config.phone_os == PhoneOS.IOS:
        io = getattr(pose, "io", None)
        if io is not None:
            bank_a, bank_b = io.a, io.b
            if bank_a:
                raw_inputs.update(
                    {
                        f"a{ch}": float(bank_a.get_float(ch))
                        for ch in range(1, 9)
                        if bank_a.has_float(ch)
                    }
                )
            if bank_b:
                for ch in range(1, 9):
                    key = f"b{ch}"
                    if bank_b.has_int(ch):
                        raw_inputs[key] = int(bank_b.get_int(ch))
                    elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
                        raw_inputs[key] = int(bank_b.get_bool(ch))
        enable = bool(raw_inputs.get("b1", 0))
    else:
        msg = self._latest_message or {}
        raw_inputs.update(
            {
                "move": bool(msg.get("move", False)),
                "scale": float(msg.get("scale", 1.0)),
                "reservedButtonA": bool(msg.get("reservedButtonA", False)),
                "reservedButtonB": bool(msg.get("reservedButtonB", False)),
            }
        )
        enable = bool(raw_inputs.get("move", False))

    # Rising edge of the enable signal: re-anchor the position calibration at
    # the current raw position before applying it.
    if enable and not self._enabled:
        self._reapply_position_calibration(raw_pos)

    # Express the raw pose relative to the calibration pose.
    pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
    rot_cal = self._calib_rot_inv * raw_rot

    self._enabled = enable

    return {
        "phone.pos": pos_cal,
        "phone.rot": rot_cal,
        "phone.raw_inputs": raw_inputs,
        "phone.enabled": self._enabled,
    }
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
    """Reject feedback: none is implemented for the phone teleoperator.

    Haptic feedback (vibrations) could be added here in the future.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
|
||||
|
||||
def disconnect(self) -> None:
    """Release the platform connection held by this teleoperator.

    On iOS the HEBI group reference is dropped. Otherwise the teleop object is
    dropped, a still-running stream thread is joined for at most one second, and
    the thread handle and cached pose are cleared.

    Raises:
        DeviceNotConnectedError: if the device is not connected.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"{self} is not connected.")

    if self.config.phone_os == PhoneOS.IOS:
        self._group = None
        return

    self._teleop = None
    worker = self._teleop_thread
    if worker and worker.is_alive():
        worker.join(timeout=1.0)
    # Clear unconditionally so that no stale thread handle or pose survives a
    # disconnect, whether or not the thread was still alive.
    self._teleop_thread = None
    self._latest_pose = None
|
||||
@@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneOS
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_phone_action_to_robot_action")
@dataclass
class MapPhoneActionToRobotAction(ActionProcessor):
    """
    Map calibrated phone pose (actions) to the inputs for robot actions

    Expected input ACTION keys:
    {
        "action.phone.enabled": bool,
        "action.phone.pos": np.ndarray,
        "action.phone.rot": Rotation,
        "action.phone.raw_inputs": dict,
    }

    Output ACTION keys:
    {
        "action.enabled": bool,
        "action.target_{x,y,z,wx,wy,wz}": float,
        "action.gripper": float,
    }
    """

    platform: PhoneOS
    # Not read or written by any method of this class.
    _enabled_prev: bool = field(default=False, init=False, repr=False)

    def action(self, act: dict | None) -> dict:
        """Replace the phone-specific action keys with robot target keys.

        The four ``action.phone.*`` keys are always removed from ``act``, which is
        modified in place. When position or rotation is missing, nothing is added.

        Args:
            act: Action dict containing the phone keys, or ``None``.

        Returns:
            The updated ``act``; an empty dict when ``act`` is ``None``.
        """
        if act is None:
            return {}

        # Pop them from the action
        enabled = bool(act.pop("action.phone.enabled", False))
        pos = act.pop("action.phone.pos", None)
        rot = act.pop("action.phone.rot", None)
        inputs = act.pop("action.phone.raw_inputs", {})

        if pos is None or rot is None:
            return act

        rotvec = rot.as_rotvec()  # Absolute orientation as rotvec

        # Map certain inputs to certain actions
        if self.platform == PhoneOS.IOS:
            gripper = float(inputs.get("a3", 0.0))
        else:
            a = float(inputs.get("reservedButtonA", 0.0))
            b = float(inputs.get("reservedButtonB", 0.0))
            # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
            gripper = a - b

        # For some actions we need to invert the axis: x/y (and wx/wy) are
        # swapped, and target_x and target_wz are negated. All targets are zeroed
        # while disabled.
        act.update(
            {
                "action.enabled": enabled,
                "action.target_x": -pos[1] if enabled else 0.0,
                "action.target_y": pos[0] if enabled else 0.0,
                "action.target_z": pos[2] if enabled else 0.0,
                "action.target_wx": rotvec[1] if enabled else 0.0,
                "action.target_wy": rotvec[0] if enabled else 0.0,
                "action.target_wz": -rotvec[2] if enabled else 0.0,
                "action.gripper": gripper,  # Still send gripper action when disabled
            }
        )
        return act

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Return ``features`` unchanged."""
        return features
|
||||
@@ -12,10 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
|
||||
|
||||
class TeleopEvents(Enum):
    """Event names shared by all teleoperators.

    Each member's value is the lower-case string form of its name.
    """

    # Episode outcome
    SUCCESS = "success"
    FAILURE = "failure"

    # Episode control
    RERECORD_EPISODE = "rerecord_episode"
    IS_INTERVENTION = "is_intervention"
    TERMINATE_EPISODE = "terminate_episode"
|
||||
|
||||
|
||||
def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
if config.type == "keyboard":
|
||||
from .keyboard import KeyboardTeleop
|
||||
|
||||
@@ -31,6 +31,7 @@ from termcolor import colored
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor, TransitionKey
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
@@ -101,6 +102,8 @@ def predict_action(
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
preprocessor: RobotProcessor,
|
||||
postprocessor: RobotProcessor,
|
||||
use_amp: bool,
|
||||
task: str | None = None,
|
||||
robot_type: str | None = None,
|
||||
@@ -122,10 +125,14 @@ def predict_action(
|
||||
observation["task"] = task if task else ""
|
||||
observation["robot_type"] = robot_type if robot_type else ""
|
||||
|
||||
observation = preprocessor(observation)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
|
||||
|
||||
_torch_available, _torch_version = is_package_available("torch", return_version=True)
|
||||
_transformers_available = is_package_available("transformers")
|
||||
_gym_xarm_available = is_package_available("gym_xarm")
|
||||
_gym_aloha_available = is_package_available("gym_aloha")
|
||||
_gym_pusht_available = is_package_available("gym_pusht")
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from lerobot.utils.utils import format_big_number
|
||||
|
||||
@@ -84,6 +84,7 @@ class MetricsTracker:
|
||||
"samples",
|
||||
"episodes",
|
||||
"epochs",
|
||||
"accelerator",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -93,12 +94,14 @@ class MetricsTracker:
|
||||
num_episodes: int,
|
||||
metrics: dict[str, AverageMeter],
|
||||
initial_step: int = 0,
|
||||
accelerator: Callable | None = None,
|
||||
):
|
||||
self.__dict__.update(dict.fromkeys(self.__keys__))
|
||||
self._batch_size = batch_size
|
||||
self._num_frames = num_frames
|
||||
self._avg_samples_per_ep = num_frames / num_episodes
|
||||
self.metrics = metrics
|
||||
self.accelerator = accelerator
|
||||
|
||||
self.steps = initial_step
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
@@ -128,7 +131,7 @@ class MetricsTracker:
|
||||
Updates metrics that depend on 'step' for one step.
|
||||
"""
|
||||
self.steps += 1
|
||||
self.samples += self._batch_size
|
||||
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import random
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Callable, Generator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -164,7 +164,7 @@ def set_rng_state(random_state_dict: dict[str, Any]):
|
||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_seed(seed) -> None:
|
||||
def set_seed(seed: int, accelerator: Callable | None = None) -> None:
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
@@ -172,6 +172,11 @@ def set_seed(seed) -> None:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
if accelerator:
|
||||
from accelerate.utils import set_seed as accelerate_set_seed
|
||||
|
||||
accelerate_set_seed(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
|
||||
@@ -32,6 +32,7 @@ from lerobot.datasets.utils import load_json, write_json
|
||||
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
||||
|
||||
|
||||
@@ -74,6 +75,8 @@ def save_checkpoint(
|
||||
policy: PreTrainedPolicy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: RobotProcessor | None = None,
|
||||
postprocessor: RobotProcessor | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -81,7 +84,9 @@ def save_checkpoint(
|
||||
├── pretrained_model/
|
||||
│ ├── config.json # policy config
|
||||
│ ├── model.safetensors # policy weights
|
||||
│ └── train_config.json # train config
|
||||
│ ├── train_config.json # train config
|
||||
│ ├── processor.json # processor config (if preprocessor provided)
|
||||
│ └── step_*.safetensors # processor state files (if any)
|
||||
└── training_state/
|
||||
├── optimizer_param_groups.json # optimizer param groups
|
||||
├── optimizer_state.safetensors # optimizer state
|
||||
@@ -95,10 +100,15 @@ def save_checkpoint(
|
||||
policy (PreTrainedPolicy): The policy to save.
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
cfg.save_pretrained(pretrained_dir)
|
||||
if preprocessor is not None:
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import time
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
@@ -56,13 +57,15 @@ def auto_select_torch_device() -> torch.device:
|
||||
|
||||
|
||||
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
def get_safe_torch_device(
|
||||
try_device: str, log: bool = False, accelerator: Callable | None = None
|
||||
) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
try_device = str(try_device)
|
||||
match try_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
device = accelerator.device if accelerator else torch.device("cuda")
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
@@ -116,6 +119,7 @@ def init_logging(
|
||||
display_pid: bool = False,
|
||||
console_level: str = "INFO",
|
||||
file_level: str = "DEBUG",
|
||||
accelerator: Callable | None = None,
|
||||
):
|
||||
def custom_format(record: logging.LogRecord) -> str:
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -152,6 +156,11 @@ def init_logging(
|
||||
file_handler.setLevel(file_level.upper())
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
if accelerator is not None and not accelerator.is_main_process:
|
||||
# Disable duplicate logging on non-main processes
|
||||
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
@@ -165,6 +174,10 @@ def format_big_number(num, precision=0):
|
||||
return num
|
||||
|
||||
|
||||
def is_launched_with_accelerate() -> bool:
    """Report whether ``ACCELERATE_MIXED_PRECISION`` is set in the process environment."""
    marker = os.environ.get("ACCELERATE_MIXED_PRECISION")
    return marker is not None
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
|
||||
@@ -12,12 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numbers
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
"""Initializes the Rerun SDK for visualizing the control loop."""
|
||||
@@ -28,19 +31,87 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
|
||||
def log_rerun_data(observation: dict[str | Any], action: dict[str | Any]):
|
||||
for obs, val in observation.items():
|
||||
if isinstance(val, float):
|
||||
rr.log(f"observation.{obs}", rr.Scalar(val))
|
||||
elif isinstance(val, np.ndarray):
|
||||
if val.ndim == 1:
|
||||
for i, v in enumerate(val):
|
||||
rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v)))
|
||||
def _is_scalar(x):
|
||||
return (
|
||||
isinstance(x, numbers.Real)
|
||||
or isinstance(x, (np.integer, np.floating))
|
||||
or (isinstance(x, np.ndarray) and x.ndim == 0)
|
||||
)
|
||||
|
||||
|
||||
def log_rerun_data(
    data: list[dict[str, Any] | EnvTransition] | dict[str, Any] | EnvTransition | None = None,
    *,
    observation: dict[str, Any] | None = None,
    action: dict[str, Any] | None = None,
) -> None:
    """Log observation and action values to Rerun.

    Values are gathered into an observation bucket and an action bucket, then
    logged: scalars as ``rr.Scalar``, 1-D arrays as one scalar per element,
    higher-dimensional observation arrays as static images, and
    higher-dimensional action arrays flattened into scalars. ``None`` values
    are skipped.

    Args:
        data: One item or a list of items. Non-dict items are ignored. A dict
            with ``TransitionKey`` keys contributes its OBSERVATION / ACTION
            entries when those are dicts. A flat dict is routed by its
            ``"observation."`` / ``"action."`` key prefixes. A dict without any
            prefixed key is treated as the action when it is the second list
            item and as an observation otherwise.
        observation: Extra observation values, used as the initial bucket.
        action: Extra action values, used as the initial bucket.
    """
    items = data if isinstance(data, list) else ([data] if data is not None else [])

    obs = {} if observation is None else dict(observation)
    act = {} if action is None else dict(action)

    for idx, item in enumerate(items):
        if not isinstance(item, dict):
            continue

        if any(isinstance(k, TransitionKey) for k in item):
            # Test the type directly rather than using `value or {}`: truth-testing
            # a multi-element array or tensor payload raises.
            o = item.get(TransitionKey.OBSERVATION)
            a = item.get(TransitionKey.ACTION)
            if isinstance(o, dict):
                obs.update(o)
            if isinstance(a, dict):
                act.update(a)
            continue

        has_obs = any(str(k).startswith("observation.") for k in item)
        has_act = any(str(k).startswith("action.") for k in item)

        if has_obs and has_act:
            # Mixed dict: split by prefix so that action entries are not also
            # logged under "observation." (and vice versa).
            for k, v in item.items():
                target = act if str(k).startswith("action.") else obs
                target[k] = v
        elif has_obs:
            obs.update(item)
        elif has_act:
            act.update(item)
        elif idx == 1:
            # No prefixes: assume first is observation, second is action, others are observation
            act.update(item)
        else:
            obs.update(item)

    for k, v in obs.items():
        if v is None:
            continue
        key = k if str(k).startswith("observation.") else f"observation.{k}"

        if _is_scalar(v):
            rr.log(key, rr.Scalar(float(v)))
        elif isinstance(v, np.ndarray):
            arr = v
            # Convert CHW -> HWC when needed
            if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                arr = np.transpose(arr, (1, 2, 0))
            if arr.ndim == 1:
                for i, vi in enumerate(arr):
                    rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
            else:
                rr.log(key, rr.Image(arr), static=True)

    for k, v in act.items():
        if v is None:
            continue
        key = k if str(k).startswith("action.") else f"action.{k}"

        if _is_scalar(v):
            rr.log(key, rr.Scalar(float(v)))
        elif isinstance(v, np.ndarray):
            # 1-D arrays are logged element-wise; higher-dimensional arrays are
            # flattened first, which is a no-op for the 1-D case.
            for i, vi in enumerate(v.flatten()):
                rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
|
||||
oid sha256:ee0c29d3782aa1cadcf4dc6ed767d9460ff00fff9fc70b460502340b832eefcc
|
||||
size 5104
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
|
||||
size 33400
|
||||
oid sha256:ea76e6711959fd3f905ec2bdc306f488920f00ec99421e4870d05f6205eb323e
|
||||
size 31672
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
|
||||
oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c
|
||||
size 515400
|
||||
|
||||
+2
-2
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
|
||||
size 33400
|
||||
oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
|
||||
oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
|
||||
size 992
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
|
||||
oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
|
||||
size 47424
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
|
||||
size 49120
|
||||
oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
|
||||
size 47424
|
||||
|
||||
@@ -23,7 +23,8 @@ from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.policies.factory import make_policy, make_policy_config, make_processor
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
|
||||
@@ -37,7 +38,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
train_cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
dataset = make_dataset(train_cfg)
|
||||
dataset_stats = dataset.meta.stats
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_processor(train_cfg.policy, dataset_stats=dataset_stats)
|
||||
policy.train()
|
||||
|
||||
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
|
||||
@@ -49,7 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
batch = preprocessor(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
if output_dict is not None:
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict["loss"] = loss
|
||||
@@ -96,7 +101,12 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
else:
|
||||
actions_queue = train_cfg.policy.n_action_repeats
|
||||
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
actions = {}
|
||||
for i in range(actions_queue):
|
||||
unnormalized_action = policy.select_action(obs).contiguous()
|
||||
action_robot = postprocessor({TransitionKey.ACTION: unnormalized_action}).get(TransitionKey.ACTION)
|
||||
actions[str(i)] = action_robot
|
||||
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
|
||||
oid sha256:d640988f2269cf6aa03c8ee17f9d096edace83d837f90025011fafec5bf53c61
|
||||
size 200
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
||||
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
|
||||
size 16904
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
||||
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
|
||||
size 164
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
||||
size 36312
|
||||
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
|
||||
size 35496
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
|
||||
oid sha256:2212ae7b910d14d723214f5af50985e419f7bd0f4261565ef48b1ef495443d6d
|
||||
size 200
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
||||
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
|
||||
size 16904
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
||||
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
|
||||
size 164
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
||||
size 36312
|
||||
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
|
||||
size 35496
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch, merge_features
|
||||
|
||||
|
||||
def test_default_parameters():
    """A card created without arguments carries the default tags, task category and config."""
    expected_configs = [
        {
            "config_name": "default",
            "data_files": "data/*/*.parquet",
        }
    ]

    card = create_lerobot_dataset_card()

    assert isinstance(card, DatasetCard)
    assert card.data.tags == ["LeRobot"]
    assert card.data.task_categories == ["robotics"]
    assert card.data.configs == expected_configs
|
||||
|
||||
|
||||
def test_with_tags():
    """Caller-supplied tags are appended after the default "LeRobot" tag."""
    card = create_lerobot_dataset_card(tags=["tag1", "tag2"])
    assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
||||
|
||||
|
||||
def test_calculate_episode_data_index():
    """Episode boundaries are derived from the ``episode_index`` column."""
    # Three episodes of lengths 2, 1 and 3.
    columns = {
        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "index": [0, 1, 2, 3, 4, 5],
        "episode_index": [0, 0, 1, 2, 2, 2],
    }
    dataset = Dataset.from_dict(columns)
    dataset.set_transform(hf_transform_to_torch)

    episode_data_index = calculate_episode_data_index(dataset)

    expected_from = torch.tensor([0, 2, 3])
    expected_to = torch.tensor([2, 3, 6])
    assert torch.equal(episode_data_index["from"], expected_from)
    assert torch.equal(episode_data_index["to"], expected_to)
|
||||
|
||||
|
||||
def test_merge_simple_vectors():
    """Two vector features with overlapping names merge into one de-duplicated feature."""
    first = {"action": {"dtype": "float32", "shape": (2,), "names": ["ee.x", "ee.y"]}}
    second = {"action": {"dtype": "float32", "shape": (2,), "names": ["ee.y", "ee.z"]}}

    merged = merge_features(first, second)

    assert "action" in merged
    action = merged["action"]
    assert action["dtype"] == "float32"
    # Names merged with preserved order and de-duplication
    assert action["names"] == ["ee.x", "ee.y", "ee.z"]
    # Shape correctly recomputed from names length
    assert action["shape"] == (3,)
|
||||
|
||||
|
||||
def test_merge_multiple_groups_order_and_dedup():
    """Merging three groups keeps first-seen name order and drops duplicates."""
    groups = [
        {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}},
        {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}},
        {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}},
    ]

    merged = merge_features(*groups)

    assert merged["action"]["names"] == ["a", "b", "c", "d"]
    assert merged["action"]["shape"] == (4,)
|
||||
|
||||
|
||||
def test_non_vector_last_wins_for_images():
    """For non-vector (image) features with the same name, the last definition wins."""
    key = "observation.images.front"
    axis_names = ["channels", "height", "width"]
    low_res = {key: {"dtype": "image", "shape": (3, 480, 640), "names": list(axis_names)}}
    high_res = {key: {"dtype": "image", "shape": (3, 720, 1280), "names": list(axis_names)}}

    merged = merge_features(low_res, high_res)

    assert merged[key]["shape"] == (3, 720, 1280)
    assert merged[key]["dtype"] == "image"
|
||||
|
||||
|
||||
def test_dtype_mismatch_raises():
|
||||
g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}}
|
||||
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
|
||||
|
||||
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
|
||||
_ = merge_features(g1, g2)
|
||||
|
||||
|
||||
def test_non_dict_passthrough_last_wins():
|
||||
g1 = {"misc": 123}
|
||||
g2 = {"misc": 456}
|
||||
|
||||
out = merge_features(g1, g2)
|
||||
# For non-dict entries the last one wins
|
||||
assert out["misc"] == 456
|
||||
@@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
assert card.data.tags == ["LeRobot"]
|
||||
assert card.data.task_categories == ["robotics"]
|
||||
assert card.data.configs == [
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_with_tags():
|
||||
tags = ["tag1", "tag2"]
|
||||
card = create_lerobot_dataset_card(tags=tags)
|
||||
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
||||
|
||||
|
||||
def test_calculate_episode_data_index():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
@@ -26,7 +26,7 @@ from safetensors.torch import load_file
|
||||
from lerobot import available_policies
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
@@ -39,8 +39,8 @@ from lerobot.policies.factory import (
|
||||
get_policy_class,
|
||||
make_policy,
|
||||
make_policy_config,
|
||||
make_processor,
|
||||
)
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
@@ -151,6 +151,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(train_cfg)
|
||||
preprocessor, _ = make_processor(train_cfg.policy, None)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
assert isinstance(policy, PreTrainedPolicy)
|
||||
|
||||
@@ -224,6 +225,7 @@ def test_act_backbone_lr():
|
||||
assert cfg.policy.optimizer_lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
preprocessor, _ = make_processor(cfg.policy, None)
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
@@ -263,108 +265,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
def test_normalize(insert_temporal_dim):
|
||||
"""
|
||||
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
|
||||
an exception when the forward pass is called without the stats having been provided.
|
||||
|
||||
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
|
||||
expected.
|
||||
"""
|
||||
|
||||
input_features = {
|
||||
"observation.image": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 96, 96),
|
||||
),
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(10,),
|
||||
),
|
||||
}
|
||||
output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
"min": torch.randn(3, 1, 1),
|
||||
"max": torch.randn(3, 1, 1),
|
||||
},
|
||||
"observation.state": {
|
||||
"mean": torch.randn(10),
|
||||
"std": torch.randn(10),
|
||||
"min": torch.randn(10),
|
||||
"max": torch.randn(10),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.randn(5),
|
||||
"std": torch.randn(5),
|
||||
"min": torch.randn(5),
|
||||
"max": torch.randn(5),
|
||||
},
|
||||
}
|
||||
|
||||
bsize = 2
|
||||
input_batch = {
|
||||
"observation.image": torch.randn(bsize, 3, 96, 96),
|
||||
"observation.state": torch.randn(bsize, 10),
|
||||
}
|
||||
output_batch = {
|
||||
"action": torch.randn(bsize, 5),
|
||||
}
|
||||
|
||||
if insert_temporal_dim:
|
||||
tdim = 4
|
||||
|
||||
for key in input_batch:
|
||||
# [2,3,96,96] -> [2,tdim,3,96,96]
|
||||
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
|
||||
|
||||
for key in output_batch:
|
||||
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
||||
|
||||
# test without stats
|
||||
normalize = Normalize(input_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
|
||||
normalize(input_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_normalize = Normalize(input_features, norm_map, stats=None)
|
||||
new_normalize.load_state_dict(normalize.state_dict())
|
||||
new_normalize(input_batch)
|
||||
|
||||
# test without stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
||||
unnormalize(output_batch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multikey", [True, False])
|
||||
def test_multikey_construction(multikey: bool):
|
||||
"""
|
||||
@@ -464,6 +364,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
NOTE: If the test doesn't pass and you changed neither the policy nor the dependency versions,
|
||||
but you did change your processor, you might have to update the test artifact.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for ACT policy processor."""
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.processor_act import make_act_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default ACT configuration for testing."""
|
||||
config = ACTConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
config.device = "cpu"
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
ACTION: {"mean": torch.zeros(4), "std": torch.ones(4)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_act_processor_basic():
|
||||
"""Test basic creation of ACT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
|
||||
|
||||
def test_act_processor_normalization():
|
||||
"""Test that ACT processor correctly normalizes and unnormalizes data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is normalized and batched
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized
|
||||
assert postprocessed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_cuda():
|
||||
"""Test ACT processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_accelerate_scenario():
|
||||
"""Test ACT processor in simulated Accelerate scenario (data already on GPU)."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU (not moved unnecessarily)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_act_processor_multi_gpu():
|
||||
"""Test ACT processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU (like in multi-GPU training)
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)}
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1 (not moved to cuda:0)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_act_processor_without_stats():
|
||||
"""Test ACT processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors, but normalization won't have stats
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work (but won't normalize without stats)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_act_processor_save_and_load():
|
||||
"""Test saving and loading ACT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
def test_act_processor_device_placement_preservation():
|
||||
"""Test that ACT processor preserves device placement correctly."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
# Test with CPU config
|
||||
config.device = "cpu"
|
||||
preprocessor, _ = make_act_processor(config, stats)
|
||||
|
||||
# Process CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_mixed_precision():
|
||||
"""Test ACT processor with mixed precision (float16)."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Modify the device processor to use float16
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.float16
|
||||
|
||||
|
||||
def test_act_processor_batch_consistency():
|
||||
"""Test that ACT processor handles different batch sizes correctly."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Test single sample (unbatched)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
|
||||
|
||||
# Test already batched data
|
||||
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
|
||||
action_batched = torch.randn(8, 4)
|
||||
transition_batched = create_transition(observation_batched, action_batched)
|
||||
|
||||
processed_batched = preprocessor(transition_batched)
|
||||
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
|
||||
assert processed_batched[TransitionKey.ACTION].shape[0] == 8
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user