mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 37103baa07 | |||
| 35c5d43255 | |||
| 95c1e32aa5 | |||
| e4db65a127 | |||
| 0053defa2e | |||
| 0878c6880f | |||
| fd5d8b3d5f | |||
| 5bf82f8229 | |||
| 5ca3920611 | |||
| 8bde9d0ab7 | |||
| abcbc16126 | |||
| e4fd30a8d4 | |||
| 11e6bd762a | |||
| ce3b9f627e |
@@ -30,7 +30,7 @@ pytest -sx tests/test_stuff.py::test_something
|
||||
```
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --some.option=true
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
|
||||
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
|
||||
|
||||
@@ -29,8 +29,8 @@ on:
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-gpu:latest
|
||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-cpu:latest
|
||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
||||
|
||||
# Ensures that only the latest commit is built, canceling older runs.
|
||||
concurrency:
|
||||
|
||||
@@ -44,7 +44,7 @@ test-end-to-end:
|
||||
${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-eval
|
||||
|
||||
test-act-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--policy.dim_model=64 \
|
||||
--policy.n_action_steps=20 \
|
||||
@@ -68,12 +68,12 @@ test-act-ete-train:
|
||||
--output_dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-train-resume:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
|
||||
test-act-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
@@ -82,7 +82,7 @@ test-act-ete-eval:
|
||||
--eval.batch_size=1
|
||||
|
||||
test-diffusion-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=diffusion \
|
||||
--policy.down_dims='[64,128,256]' \
|
||||
--policy.diffusion_step_embed_dim=32 \
|
||||
@@ -106,7 +106,7 @@ test-diffusion-ete-train:
|
||||
--output_dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=pusht \
|
||||
@@ -115,7 +115,7 @@ test-diffusion-ete-eval:
|
||||
--eval.batch_size=1
|
||||
|
||||
test-tdmpc-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=tdmpc \
|
||||
--policy.device=$(DEVICE) \
|
||||
--policy.push_to_hub=false \
|
||||
@@ -137,7 +137,7 @@ test-tdmpc-ete-train:
|
||||
--output_dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=xarm \
|
||||
@@ -148,7 +148,7 @@ test-tdmpc-ete-eval:
|
||||
|
||||
|
||||
test-smolvla-ete-train:
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.n_action_steps=20 \
|
||||
--policy.chunk_size=20 \
|
||||
@@ -171,7 +171,7 @@ test-smolvla-ete-train:
|
||||
--output_dir=tests/outputs/smolvla/
|
||||
|
||||
test-smolvla-ete-eval:
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=tests/outputs/smolvla/checkpoints/000004/pretrained_model \
|
||||
--policy.device=$(DEVICE) \
|
||||
--env.type=aloha \
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/nighty.yml?query=branch%3Amain)
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml?query=branch%3Amain)
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
@@ -276,7 +276,7 @@ Check out [example 2](https://github.com/huggingface/lerobot/blob/main/examples/
|
||||
We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/diffusion_pusht \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
@@ -288,10 +288,10 @@ python -m lerobot.scripts.eval \
|
||||
Note: After training your own policy, you can re-evaluate the checkpoints with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
lerobot-eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
See `python -m lerobot.scripts.eval --help` for more instructions.
|
||||
See `lerobot-eval --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
@@ -303,7 +303,7 @@ A link to the wandb logs for the run will also show up in yellow in your termina
|
||||
|
||||
\<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/wandb.png" alt="WandB logs example"\>
|
||||
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python -m lerobot.scripts.eval --help` for more instructions.
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `lerobot-eval --help` for more instructions.
|
||||
|
||||
#### Reproduce state-of-the-art (SOTA)
|
||||
|
||||
@@ -311,7 +311,7 @@ We provide some pretrained policies on our [hub page](https://huggingface.co/ler
|
||||
You can reproduce their training by loading the config from their run. Simply running:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
|
||||
lerobot-train --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
@@ -9,7 +9,7 @@ To instantiate a camera, you need a camera identifier. This identifier might cha
|
||||
To find the camera indices of the cameras plugged into your system, run the following script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_cameras opencv # or realsense for Intel Realsense cameras
|
||||
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
```
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
|
||||
+384
-58
@@ -4,7 +4,13 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient
|
||||
|
||||
HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process.
|
||||
|
||||
It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
It combines three key ingredients:
|
||||
|
||||
1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
|
||||
|
||||
2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
|
||||
|
||||
3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
|
||||
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
|
||||
|
||||
@@ -56,30 +62,243 @@ pip install -e ".[hilserl]"
|
||||
|
||||
### Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as:
|
||||
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
class GymManipulatorConfig:
|
||||
env: HILSerlRobotEnvConfig # Environment configuration (nested)
|
||||
dataset: DatasetConfig # Dataset recording/replay configuration (nested)
|
||||
mode: str | None = None # "record", "replay", or None (for training)
|
||||
device: str = "cpu" # Compute device
|
||||
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`)
|
||||
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`)
|
||||
wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py`
|
||||
fps: int = 10 # Control frequency
|
||||
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm
|
||||
processor: HILSerlProcessorConfig # Processing pipeline configuration (nested)
|
||||
name: str = "real_robot" # Environment name
|
||||
mode: str = None # "record", "replay", or None (for training)
|
||||
repo_id: str | None = None # LeRobot dataset repository ID
|
||||
dataset_root: str | None = None # Local dataset root (optional)
|
||||
task: str = "" # Task identifier
|
||||
num_episodes: int = 10 # Number of episodes for recording
|
||||
episode: int = 0 # episode index for replay
|
||||
device: str = "cuda" # Compute device
|
||||
push_to_hub: bool = True # Whether to push the recorded datasets to Hub
|
||||
pretrained_policy_name_or_path: str | None = None # For policy loading
|
||||
reward_classifier_pretrained_path: str | None = None # For reward model
|
||||
number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier
|
||||
task: str | None = None # Task identifier
|
||||
fps: int = 10 # Control frequency
|
||||
|
||||
# Nested processor configuration
|
||||
class HILSerlProcessorConfig:
|
||||
control_mode: str = "gamepad" # Control mode
|
||||
observation: ObservationConfig | None = None # Observation processing settings
|
||||
image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings
|
||||
gripper: GripperConfig | None = None # Gripper control and penalty settings
|
||||
reset: ResetConfig | None = None # Environment reset and timing settings
|
||||
inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings
|
||||
reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings
|
||||
max_gripper_pos: float | None = 100.0 # Maximum gripper position
|
||||
|
||||
# Sub-configuration classes
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters
|
||||
resize_size: tuple[int, int] | None = None # Target image size
|
||||
|
||||
class GripperConfig:
|
||||
use_gripper: bool = True # Enable gripper control
|
||||
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
|
||||
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
|
||||
|
||||
class ResetConfig:
|
||||
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
|
||||
reset_time_s: float = 5.0 # Time to wait during reset
|
||||
control_time_s: float = 20.0 # Maximum episode duration
|
||||
terminate_on_success: bool = True # Whether to terminate episodes on success detection
|
||||
|
||||
class InverseKinematicsConfig:
|
||||
urdf_path: str | None = None # Path to robot URDF file
|
||||
target_frame_name: str | None = None # End-effector frame name
|
||||
end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds
|
||||
end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis
|
||||
|
||||
class RewardClassifierConfig:
|
||||
pretrained_path: str | None = None # Path to pretrained reward classifier
|
||||
success_threshold: float = 0.5 # Success detection threshold
|
||||
success_reward: float = 1.0 # Reward value for successful episodes
|
||||
|
||||
# Dataset configuration
|
||||
class DatasetConfig:
|
||||
repo_id: str # LeRobot dataset repository ID
|
||||
dataset_root: str # Local dataset root directory
|
||||
task: str # Task identifier
|
||||
num_episodes: int # Number of episodes for recording
|
||||
episode: int # Episode index for replay
|
||||
push_to_hub: bool # Whether to push datasets to Hub
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Processor Pipeline Architecture
|
||||
|
||||
HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components:
|
||||
|
||||
#### Environment Processor Pipeline
|
||||
|
||||
The environment processor (`env_processor`) handles incoming observations and environment state:
|
||||
|
||||
1. **VanillaObservationProcessor**: Converts raw robot observations into standardized format
|
||||
2. **JointVelocityProcessor** (optional): Adds joint velocity information to observations
|
||||
3. **MotorCurrentProcessor** (optional): Adds motor current readings to observations
|
||||
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
|
||||
5. **ImageCropResizeProcessor** (optional): Crops and resizes camera images
|
||||
6. **TimeLimitProcessor** (optional): Enforces episode time limits
|
||||
7. **GripperPenaltyProcessor** (optional): Applies penalties for inappropriate gripper usage
|
||||
8. **RewardClassifierProcessor** (optional): Automated reward detection using vision models
|
||||
9. **ToBatchProcessor**: Converts data to batch format for neural network processing
|
||||
10. **DeviceProcessor**: Moves data to the specified compute device (CPU/GPU)
|
||||
|
||||
#### Action Processor Pipeline
|
||||
|
||||
The action processor (`action_processor`) handles outgoing actions and human interventions:
|
||||
|
||||
1. **AddTeleopActionAsComplimentaryData**: Captures teleoperator actions for logging
|
||||
2. **AddTeleopEventsAsInfo**: Records intervention events and episode control signals
|
||||
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
|
||||
4. **InterventionActionProcessor**: Handles human interventions and episode termination
|
||||
5. **Inverse Kinematics Pipeline** (when enabled):
|
||||
- **MapDeltaActionToRobotAction**: Converts delta actions to robot action format
|
||||
- **EEReferenceAndDelta**: Computes end-effector reference and delta movements
|
||||
- **EEBoundsAndSafety**: Enforces workspace safety bounds
|
||||
- **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
|
||||
- **GripperVelocityToJoint**: Handles gripper control commands
|
||||
|
||||
#### Configuration Examples
|
||||
|
||||
**Basic Observation Processing**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Image Processing**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {
|
||||
"observation.images.front": [180, 250, 120, 150],
|
||||
"observation.images.side": [180, 207, 180, 200]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Inverse Kinematics Setup**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"inverse_kinematics": {
|
||||
"urdf_path": "path/to/robot.urdf",
|
||||
"target_frame_name": "end_effector",
|
||||
"end_effector_bounds": {
|
||||
"min": [0.16, -0.08, 0.03],
|
||||
"max": [0.24, 0.2, 0.1]
|
||||
},
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Observation Processing
|
||||
|
||||
The HIL-SERL framework supports additional observation processing features that can improve policy learning:
|
||||
|
||||
#### Joint Velocity Processing
|
||||
|
||||
Enable joint velocity estimation to provide the policy with motion information:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This processor:
|
||||
|
||||
- Estimates joint velocities using finite differences between consecutive joint position readings
|
||||
- Adds velocity information to the observation state vector
|
||||
- Useful for policies that need motion awareness for dynamic tasks
|
||||
|
||||
#### Motor Current Processing
|
||||
|
||||
Monitor motor currents to detect contact forces and load conditions:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_current_to_observation": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This processor:
|
||||
|
||||
- Reads motor current values from the robot's control system
|
||||
- Adds current measurements to the observation state vector
|
||||
- Helps detect contact events, object weights, and mechanical resistance
|
||||
- Useful for contact-rich manipulation tasks
|
||||
|
||||
#### Combined Observation Processing
|
||||
|
||||
You can enable multiple observation processing features simultaneously:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": true,
|
||||
"add_ee_pose_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data.
|
||||
|
||||
### Finding Robot Workspace Bounds
|
||||
|
||||
Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.
|
||||
@@ -130,22 +349,56 @@ With the bounds defined, you can safely collect demonstrations for training. Tra
|
||||
|
||||
Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)):
|
||||
|
||||
1. Set `mode` to `"record"`
|
||||
2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name")
|
||||
3. Set `num_episodes` to the number of demonstrations you want to collect
|
||||
4. Set `crop_params_dict` to `null` initially (we'll determine crops later)
|
||||
5. Configure `robot`, `cameras`, and other hardware settings
|
||||
1. Set `mode` to `"record"` at the root level
|
||||
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
|
||||
3. Set `num_episodes` in the `dataset` section to the number of demonstrations you want to collect
|
||||
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
|
||||
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
|
||||
|
||||
Example configuration section:
|
||||
|
||||
```json
|
||||
"mode": "record",
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"dataset_root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes": 15,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "real_robot",
|
||||
"fps": 10,
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"observation": {
|
||||
"display_cameras": false
|
||||
},
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {},
|
||||
"resize_size": [128, 128]
|
||||
},
|
||||
"gripper": {
|
||||
"use_gripper": true,
|
||||
"gripper_penalty": 0.0
|
||||
},
|
||||
"reset": {
|
||||
"reset_time_s": 5.0,
|
||||
"control_time_s": 20.0
|
||||
}
|
||||
},
|
||||
"robot": {
|
||||
// ... robot configuration ...
|
||||
},
|
||||
"teleop": {
|
||||
// ... teleoperator configuration ...
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"dataset_root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes": 15,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cpu"
|
||||
}
|
||||
```
|
||||
|
||||
### Using a Teleoperation Device
|
||||
@@ -191,10 +444,20 @@ The gamepad provides a very convenient way to control the robot and the episode
|
||||
To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file.
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"teleop": {
|
||||
"type": "gamepad",
|
||||
"use_gripper": true
|
||||
"type": "gamepad",
|
||||
"use_gripper": true
|
||||
},
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
@@ -216,11 +479,21 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll
|
||||
To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file.
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"teleop": {
|
||||
"type": "so101_leader",
|
||||
"port": "/dev/tty.usbmodem585A0077921", # check your port number
|
||||
"use_degrees": true
|
||||
"type": "so101_leader",
|
||||
"port": "/dev/tty.usbmodem585A0077921",
|
||||
"use_degrees": true
|
||||
},
|
||||
"processor": {
|
||||
"control_mode": "leader",
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure.
|
||||
@@ -251,7 +524,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e
|
||||
|
||||
During recording:
|
||||
|
||||
1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions`
|
||||
1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions`
|
||||
2. Complete the task successfully
|
||||
3. The episode ends with a reward of 1 when you press the "success" button
|
||||
4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
|
||||
@@ -310,11 +583,19 @@ observation.images.front: [180, 250, 120, 150]
|
||||
Add these crop parameters to your training configuration:
|
||||
|
||||
```json
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Recommended image resolution**
|
||||
@@ -343,26 +624,52 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r
|
||||
|
||||
**Key Parameters for Data Collection**
|
||||
|
||||
- **mode**: set it to `"record"` to collect a dataset
|
||||
- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
|
||||
- **num_episodes**: Number of episodes to record
|
||||
- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected
|
||||
- **fps**: Number of frames per second to record
|
||||
- **push_to_hub**: Whether to push the dataset to the hub
|
||||
- **mode**: set it to `"record"` to collect a dataset (at root level)
|
||||
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
|
||||
- **dataset.num_episodes**: Number of episodes to record
|
||||
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
|
||||
- **env.fps**: Number of frames per second to record
|
||||
- **dataset.push_to_hub**: Whether to push the dataset to the hub
|
||||
|
||||
The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier.
|
||||
The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection.
|
||||
|
||||
**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully.
|
||||
|
||||
Example configuration section for data collection:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "real_robot",
|
||||
"fps": 10,
|
||||
"processor": {
|
||||
"reset": {
|
||||
"reset_time_s": 5.0,
|
||||
"control_time_s": 20.0,
|
||||
"terminate_on_success": false
|
||||
},
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
},
|
||||
"robot": {
|
||||
// ... robot configuration ...
|
||||
},
|
||||
"teleop": {
|
||||
// ... teleoperator configuration ...
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"task": "reward_classifier_task",
|
||||
"num_episodes": 20,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"num_episodes": 20,
|
||||
"push_to_hub": true,
|
||||
"fps": 10,
|
||||
"number_of_steps_after_success": 15
|
||||
"device": "cpu"
|
||||
}
|
||||
```
|
||||
|
||||
@@ -412,7 +719,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
To train the classifier, use the `train.py` script with your configuration:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path path/to/reward_classifier_train_config.json
|
||||
lerobot-train --config_path path/to/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
**Deploying and Testing the Model**
|
||||
@@ -421,9 +728,17 @@ To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
env_config = HILSerlRobotEnvConfig(
|
||||
reward_classifier_pretrained_path="path_to_your_pretrained_trained_model",
|
||||
# Other environment parameters
|
||||
config = GymManipulatorConfig(
|
||||
env=HILSerlRobotEnvConfig(
|
||||
processor=HILSerlProcessorConfig(
|
||||
reward_classifier=RewardClassifierConfig(
|
||||
pretrained_path="path_to_your_pretrained_trained_model"
|
||||
)
|
||||
),
|
||||
# Other environment parameters
|
||||
),
|
||||
dataset=DatasetConfig(...),
|
||||
mode=None # For training
|
||||
)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
@@ -432,7 +747,18 @@ or set the argument in the json config file.
|
||||
|
||||
```json
|
||||
{
|
||||
"reward_classifier_pretrained_path": "path_to_your_pretrained_model"
|
||||
"env": {
|
||||
"processor": {
|
||||
"reward_classifier": {
|
||||
"pretrained_path": "path_to_your_pretrained_model",
|
||||
"success_threshold": 0.7,
|
||||
"success_reward": 1.0
|
||||
},
|
||||
"reset": {
|
||||
"terminate_on_success": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -458,7 +784,7 @@ The reward classifier will automatically provide rewards based on the visual inp
|
||||
3. **Train the classifier**:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path src/lerobot/configs/reward_classifier_train_config.json
|
||||
lerobot-train --config_path src/lerobot/configs/reward_classifier_train_config.json
|
||||
```
|
||||
|
||||
4. **Test the classifier**:
|
||||
|
||||
+56
-30
@@ -32,9 +32,12 @@ To use `gym_hil` with LeRobot, you need to create a configuration file. An examp
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "hil",
|
||||
"name": "franka_sim",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
@@ -45,28 +48,40 @@ Available tasks:
|
||||
- `PandaPickCubeGamepad-v0`: With gamepad control
|
||||
- `PandaPickCubeKeyboard-v0`: With keyboard control
|
||||
|
||||
### Gym Wrappers Configuration
|
||||
### Processor Configuration
|
||||
|
||||
```json
|
||||
"wrapper": {
|
||||
"gripper_penalty": -0.02,
|
||||
"control_time_s": 15.0,
|
||||
"use_gripper": true,
|
||||
"fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785],
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
},
|
||||
"control_mode": "gamepad"
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"gripper": {
|
||||
"use_gripper": true,
|
||||
"gripper_penalty": -0.02
|
||||
},
|
||||
"reset": {
|
||||
"control_time_s": 15.0,
|
||||
"fixed_reset_joint_positions": [
|
||||
0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785
|
||||
]
|
||||
},
|
||||
"inverse_kinematics": {
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Important parameters:
|
||||
|
||||
- `gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `use_gripper`: Whether to enable gripper control
|
||||
- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `gripper.gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `gripper.use_gripper`: Whether to enable gripper control
|
||||
- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `control_mode`: Set to `"gamepad"` to use a gamepad controller
|
||||
|
||||
## Running with HIL RL of LeRobot
|
||||
@@ -75,39 +90,50 @@ Important parameters:
|
||||
|
||||
To run the environment, set mode to null:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Recording a Dataset
|
||||
|
||||
To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0"
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/sim_dataset",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes": 10,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record"
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Training a Policy
|
||||
|
||||
To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
In a different terminal, run the learner server:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.
|
||||
|
||||
|
||||
+11
-11
@@ -19,7 +19,7 @@ pip install -e ".[hopejr]"
|
||||
Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts.
|
||||
@@ -31,7 +31,7 @@ Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibrati
|
||||
### 1.1 Calibrate Robot Hand
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=blue \
|
||||
@@ -81,7 +81,7 @@ Once you have set the appropriate boundaries for all joints, click "Save" to sav
|
||||
### 1.2 Calibrate Teleoperator Glove
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=homunculus_glove \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=red \
|
||||
@@ -120,7 +120,7 @@ Once calibration is complete, the system will save the calibration to `/Users/yo
|
||||
### 1.3 Calibrate Robot Arm
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=hope_jr_arm \
|
||||
--robot.port=/dev/tty.usbserial-1110 \
|
||||
--robot.id=white
|
||||
@@ -146,7 +146,7 @@ Use the calibration interface to set the range boundaries for each joint. Move e
|
||||
### 1.4 Calibrate Teleoperator Exoskeleton
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=homunculus_arm \
|
||||
--teleop.port=/dev/tty.usbmodem11201 \
|
||||
--teleop.id=black
|
||||
@@ -178,7 +178,7 @@ Due to global variable conflicts in the Feetech middleware, teleoperation for ar
|
||||
### Hand
|
||||
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=blue \
|
||||
@@ -194,7 +194,7 @@ python -m lerobot.teleoperate \
|
||||
### Arm
|
||||
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=hope_jr_arm \
|
||||
--robot.port=/dev/tty.usbserial-1110 \
|
||||
--robot.id=white \
|
||||
@@ -214,7 +214,7 @@ Record, Replay and Train with Hope-JR is still experimental.
|
||||
This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings).
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
@@ -236,7 +236,7 @@ python -m lerobot.record \
|
||||
### Replay
|
||||
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
@@ -248,7 +248,7 @@ python -m lerobot.replay \
|
||||
### Train
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
@@ -263,7 +263,7 @@ python -m lerobot.scripts.train \
|
||||
This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino).
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=hope_jr_hand \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
|
||||
@@ -45,7 +45,7 @@ Note that the `id` associated with a robot is used to store the calibration file
|
||||
<hfoptions id="teleoperate_so101">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -101,7 +101,7 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoptions id="teleoperate_koch_camera">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -174,7 +174,7 @@ Now you can record a dataset. To record 5 episodes and upload your dataset to th
|
||||
<hfoptions id="record">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -376,7 +376,7 @@ You can replay the first episode on your robot with either the command below or
|
||||
<hfoptions id="replay">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
@@ -428,10 +428,10 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
@@ -453,7 +453,7 @@ Training should take several hours. You will find checkpoints in `outputs/train/
|
||||
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
```
|
||||
@@ -490,7 +490,7 @@ You can use the `record` script from [`lerobot/record.py`](https://github.com/hu
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
|
||||
+55
-7
@@ -24,11 +24,36 @@ pip install -e ".[hilserl]"
|
||||
|
||||
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json).
|
||||
|
||||
To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record".
|
||||
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
|
||||
|
||||
If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS).
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_gym",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes": 30,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`.
|
||||
Key configuration points:
|
||||
|
||||
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
|
||||
- Set `num_episodes: 30` to collect 30 demonstration episodes
|
||||
- Ensure `mode` is set to `"record"`
|
||||
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
|
||||
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
|
||||
|
||||
Then we can run this command to start:
|
||||
|
||||
@@ -96,10 +121,10 @@ If you uploaded your dataset to the hub you can [visualize your dataset online](
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/il_gym \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/il_sim_test \
|
||||
@@ -140,9 +165,32 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
|
||||
|
||||
## Evaluate your policy in Sim
|
||||
|
||||
To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
|
||||
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
|
||||
|
||||
Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model`
|
||||
Here's an example evaluation configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_sim_dataset",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube"
|
||||
},
|
||||
"pretrained_policy_name_or_path": "your_username/il_sim_model",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Make sure to replace:
|
||||
|
||||
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
|
||||
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
|
||||
|
||||
Then you can run this command to visualize your trained policy
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ pip install -e ".[dynamixel]"
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -98,7 +98,7 @@ For a visual reference on how to set the motor ids please refer to [this video](
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -174,7 +174,7 @@ Do the same steps for the leader arm but modify the command or script accordingl
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -211,7 +211,7 @@ Run the following command or API example to calibrate the follower arm:
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
|
||||
@@ -249,7 +249,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -60,7 +60,7 @@ First, we will assemble the two SO100/SO101 arms. One to attach to the mobile ba
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -116,7 +116,7 @@ The instructions for configuring the motors can be found in the SO101 [docs](./s
|
||||
You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7)
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=lekiwi \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -174,7 +174,7 @@ The calibration process is very important because it allows a neural network tra
|
||||
Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm:
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=lekiwi \
|
||||
--robot.id=my_awesome_kiwi # <- Give the robot a unique name
|
||||
```
|
||||
@@ -193,7 +193,7 @@ Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -54,7 +54,7 @@ If you don't have a gpu device, you can train using our notebook on [.
|
||||
|
||||
```bash
|
||||
cd lerobot && python -m lerobot.scripts.train \
|
||||
cd lerobot && lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=${HF_USER}/mydataset \
|
||||
--batch_size=64 \
|
||||
@@ -73,7 +73,7 @@ cd lerobot && python -m lerobot.scripts.train \
|
||||
Fine-tuning is an art. For a complete overview of the options for finetuning, run
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --help
|
||||
lerobot-train --help
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
@@ -97,7 +97,7 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
|
||||
@@ -26,7 +26,7 @@ Unlike the SO-101, the motor connectors are not easily accessible once the arm i
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -93,7 +93,7 @@ For a visual reference on how to set the motor ids please refer to [this video](
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -168,7 +168,7 @@ Do the same steps for the leader arm.
|
||||
<hfoptions id="setup_motors">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -568,7 +568,7 @@ Run the following command or API example to calibrate the follower arm:
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
|
||||
@@ -606,7 +606,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -162,7 +162,7 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
|
||||
To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted:
|
||||
|
||||
```bash
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<hfoptions id="example">
|
||||
@@ -240,7 +240,7 @@ Connect the usb cable from your computer and the power supply to the follower ar
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -316,7 +316,7 @@ Do the same steps for the leader arm.
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
|
||||
```
|
||||
@@ -353,7 +353,7 @@ Run the following command or API example to calibrate the follower arm:
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--robot.id=my_awesome_follower_arm # <- Give the robot a unique name
|
||||
@@ -402,7 +402,7 @@ Do the same steps to calibrate the leader arm, run the following command or API
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
|
||||
--teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
|
||||
|
||||
@@ -62,7 +62,7 @@ By default, every field takes its default value specified in the dataclass. If a
|
||||
Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
--policy.type=diffusion \
|
||||
--env.type=pusht
|
||||
@@ -77,7 +77,7 @@ Let's break this down:
|
||||
Let's see another example. Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
@@ -90,7 +90,7 @@ We now want to train a different policy for aloha on another task. We'll change
|
||||
Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
@@ -127,7 +127,7 @@ Now, let's assume that we want to reproduce the run just above. That run has pro
|
||||
We can then simply load the config values from this file using:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
```
|
||||
@@ -137,7 +137,7 @@ python -m lerobot.scripts.train \
|
||||
Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
|
||||
--output_dir=outputs/train/act_aloha_transfer_2
|
||||
--policy.n_action_steps=80
|
||||
@@ -148,7 +148,7 @@ python -m lerobot.scripts.train \
|
||||
`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
|
||||
lerobot-train --config_path=lerobot/diffusion_pusht
|
||||
```
|
||||
|
||||
will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)
|
||||
@@ -160,7 +160,7 @@ Being able to resume a training run is important in case it crashed or aborted f
|
||||
Let's reuse the command from the previous run and add a few more options:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \
|
||||
--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
|
||||
--env.type=aloha \
|
||||
@@ -179,7 +179,7 @@ INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100
|
||||
Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true
|
||||
```
|
||||
@@ -190,7 +190,7 @@ Another reason for which you might want to resume a run is simply to extend trai
|
||||
You could double the number of steps of the previous run with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000
|
||||
@@ -224,7 +224,7 @@ In addition to the features currently in Draccus, we've added a special `.path`
|
||||
For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
@@ -270,7 +270,7 @@ We'll summarize here the main use cases to remember from this tutorial.
|
||||
#### Train a policy from scratch – CLI
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=act \ # <- select 'act' policy
|
||||
--env.type=pusht \ # <- select 'pusht' environment
|
||||
--dataset.repo_id=lerobot/pusht # <- train on this dataset
|
||||
@@ -279,7 +279,7 @@ python -m lerobot.scripts.train \
|
||||
#### Train a policy from scratch - config file + CLI
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=path/to/pretrained_model \ # <- can also be a repo_id
|
||||
--policy.n_action_steps=80 # <- you may still override values
|
||||
```
|
||||
@@ -287,7 +287,7 @@ python -m lerobot.scripts.train \
|
||||
#### Resume/continue a training run
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--config_path=checkpoint/pretrained_model/ \
|
||||
--resume=true \
|
||||
--steps=200000 # <- you can change some training parameters
|
||||
@@ -296,7 +296,7 @@ python -m lerobot.scripts.train \
|
||||
#### Fine-tuning
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint
|
||||
--dataset.repo_id=lerobot/aloha_sim_insertion_human \
|
||||
--env.type=aloha \
|
||||
|
||||
@@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
|
||||
@@ -18,7 +18,7 @@ Helper to recalibrate your device (robot or teleoperator).
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.calibrate \
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue
|
||||
|
||||
@@ -60,7 +60,7 @@ class OpenCVCamera(Camera):
|
||||
or port changes, especially on Linux. Use the provided utility script to find
|
||||
available camera indices or paths:
|
||||
```bash
|
||||
python -m lerobot.find_cameras opencv
|
||||
lerobot-find-cameras opencv
|
||||
```
|
||||
|
||||
The camera's default settings (FPS, resolution, color mode) are used unless
|
||||
@@ -165,8 +165,7 @@ class OpenCVCamera(Camera):
|
||||
self.videocapture.release()
|
||||
self.videocapture = None
|
||||
raise ConnectionError(
|
||||
f"Failed to open {self}."
|
||||
f"Run `python -m lerobot.find_cameras opencv` to find available cameras."
|
||||
f"Failed to open {self}.Run `lerobot-find-cameras opencv` to find available cameras."
|
||||
)
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
@@ -51,7 +51,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Use the provided utility script to find available camera indices and default profiles:
|
||||
```bash
|
||||
python -m lerobot.find_cameras realsense
|
||||
lerobot-find-cameras realsense
|
||||
```
|
||||
|
||||
A `RealSenseCamera` instance requires a configuration object specifying the
|
||||
@@ -176,8 +176,7 @@ class RealSenseCamera(Camera):
|
||||
self.rs_profile = None
|
||||
self.rs_pipeline = None
|
||||
raise ConnectionError(
|
||||
f"Failed to open {self}."
|
||||
"Run `python -m lerobot.find_cameras realsense` to find available cameras."
|
||||
f"Failed to open {self}.Run `lerobot-find-cameras realsense` to find available cameras."
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
@@ -40,6 +40,9 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
|
||||
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||
|
||||
+57
-86
@@ -161,35 +161,73 @@ class XarmEnv(EnvConfig):
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
class ImagePreprocessingConfig:
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvTransformConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
class RewardClassifierConfig:
|
||||
"""Configuration for reward classification."""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
success_threshold: float = 0.5
|
||||
success_reward: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class InverseKinematicsConfig:
|
||||
"""Configuration for inverse kinematics processing."""
|
||||
|
||||
urdf_path: str | None = None
|
||||
target_frame_name: str | None = None
|
||||
end_effector_bounds: dict[str, list[float]] | None = None
|
||||
end_effector_step_sizes: dict[str, float] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationConfig:
|
||||
"""Configuration for observation processing."""
|
||||
|
||||
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||
control_mode: str = "gamepad"
|
||||
display_cameras: bool = False
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Any | None = None
|
||||
reset_time_s: float = 5.0
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class GripperConfig:
|
||||
"""Configuration for gripper control and penalties."""
|
||||
|
||||
use_gripper: bool = True
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetConfig:
|
||||
"""Configuration for environment reset behavior."""
|
||||
|
||||
fixed_reset_joint_positions: Any | None = None
|
||||
reset_time_s: float = 5.0
|
||||
control_time_s: float = 20.0
|
||||
terminate_on_success: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILSerlProcessorConfig:
|
||||
"""Configuration for environment processing pipeline."""
|
||||
|
||||
control_mode: str = "gamepad"
|
||||
observation: ObservationConfig | None = None
|
||||
image_preprocessing: ImagePreprocessingConfig | None = None
|
||||
gripper: GripperConfig | None = None
|
||||
reset: ResetConfig | None = None
|
||||
inverse_kinematics: InverseKinematicsConfig | None = None
|
||||
reward_classifier: RewardClassifierConfig | None = None
|
||||
max_gripper_pos: float | None = 100.0
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
@@ -197,77 +235,10 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
|
||||
robot: RobotConfig | None = None
|
||||
teleop: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
fps: int = 10
|
||||
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
|
||||
|
||||
name: str = "real_robot"
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
task: str | None = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("hil")
|
||||
@dataclass
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
name: str = "PandaPickCube"
|
||||
task: str | None = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
state_dim: int = 18
|
||||
action_dim: int = 4
|
||||
fps: int = 100
|
||||
episode_length: int = 100
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
robot_config: RobotConfig | None = None
|
||||
teleop_config: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
############################
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"use_viewer": self.use_viewer,
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "hil":
|
||||
return HILEnvConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ Helper to find the camera devices available in your system.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.find_cameras
|
||||
lerobot-find-cameras
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ Helper to find the USB port associated with your MotorsBus.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.find_port
|
||||
lerobot-find-port
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class MotorsBus(abc.ABC):
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
```bash
|
||||
python -m lerobot.find_port.py
|
||||
lerobot-find-port.py
|
||||
>>> Finding all available ports for the MotorsBus.
|
||||
>>> ["/dev/tty.usbmodem575E0032081", "/dev/tty.usbmodem575E0031751"]
|
||||
>>> Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
@@ -446,7 +446,7 @@ class MotorsBus(abc.ABC):
|
||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||
raise ConnectionError(
|
||||
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
|
||||
"\nTry running `python -m lerobot.find_port`\n"
|
||||
"\nTry running `lerobot-find-port`\n"
|
||||
) from e
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
@@ -45,6 +46,6 @@ def make_act_processor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
@@ -46,6 +47,6 @@ def make_diffusion_processor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -140,8 +140,6 @@ def make_processor(
|
||||
NotImplementedError: If the policy type doesn't have a processor implemented.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# Load a pretrained processor
|
||||
# TODO(azouitine): Handle this case.
|
||||
return (
|
||||
RobotProcessor.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
|
||||
@@ -1,420 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def create_stats_buffers(
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||
statistics.
|
||||
|
||||
Args: (see Normalize and Unnormalize)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
|
||||
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
assert isinstance(norm_mode, NormalizationMode)
|
||||
|
||||
shape = tuple(ft.shape)
|
||||
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model."
|
||||
)
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# TODO: Remove this shallow copy
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
|
||||
"""
|
||||
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
|
||||
original range used by the environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
module: nn.Module,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> None:
|
||||
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||
|
||||
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
|
||||
but is factored out so it can be reused by both classes and stay in sync.
|
||||
"""
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
shape: tuple[int, ...] = tuple(ft.shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# reduce spatial dimensions, keep channel dimension only
|
||||
c, *_ = shape
|
||||
shape = (c, 1, 1)
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||
mean_data = stats[key]["mean"]
|
||||
std_data = stats[key]["std"]
|
||||
if isinstance(mean_data, torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
mean = mean_data.clone().to(dtype=torch.float32)
|
||||
std = std_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_mean", mean)
|
||||
module.register_buffer(f"{prefix}_std", std)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||
min_data = stats[key]["min"]
|
||||
max_data = stats[key]["max"]
|
||||
if isinstance(min_data, torch.Tensor):
|
||||
min_val = min_data.clone().to(dtype=torch.float32)
|
||||
max_val = max_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_min", min_val)
|
||||
module.register_buffer(f"{prefix}_max", max_val)
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
|
||||
class NormalizeBuffer(nn.Module):
|
||||
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizeBuffer(nn.Module):
|
||||
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max_val - min_val) + min_val
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
@@ -30,7 +30,7 @@ pip install -e ".[pi0]"
|
||||
|
||||
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
@@ -38,7 +38,7 @@ python -m lerobot.scripts.train \
|
||||
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
|
||||
pretrained with VLM default parameters before pi0 finetuning:
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
@@ -115,6 +116,6 @@ def make_pi0_processor(
|
||||
),
|
||||
]
|
||||
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -25,14 +25,14 @@ Disclaimer: It is not expected to perform as well as the original implementation
|
||||
|
||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/pi0fast_base \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
Example of training the pi0+FAST neural network with from scratch:
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=pi0fast \
|
||||
--dataset.repo_id=danaaubakirova/koch_test
|
||||
```
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
@@ -46,6 +47,6 @@ def make_pi0fast_processor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
@@ -47,6 +48,6 @@ def make_sac_processor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ pip install -e ".[smolvla]"
|
||||
|
||||
Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
@@ -38,7 +38,7 @@ python -m lerobot.scripts.train \
|
||||
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
|
||||
and an action expert.
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
@@ -57,8 +58,8 @@ def make_smolvla_processor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
@@ -46,6 +47,6 @@ def make_tdmpc_processor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
@@ -47,6 +48,6 @@ def make_vqbet_processor(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor(
|
||||
steps=output_steps, name="robot_postprocessor"
|
||||
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
|
||||
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
|
||||
)
|
||||
|
||||
@@ -15,7 +15,20 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .batch_processor import ToBatchProcessor
|
||||
from .delta_action_processor import MapDeltaActionToRobotAction
|
||||
from .device_processor import DeviceProcessor
|
||||
from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryData,
|
||||
AddTeleopEventsAsInfo,
|
||||
GripperPenaltyProcessor,
|
||||
ImageCropResizeProcessor,
|
||||
InterventionActionProcessor,
|
||||
Numpy2TorchActionProcessor,
|
||||
RewardClassifierProcessor,
|
||||
TimeLimitProcessor,
|
||||
Torch2NumpyActionProcessor,
|
||||
)
|
||||
from .joint_observations_processor import JointVelocityProcessor, MotorCurrentProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessor
|
||||
from .pipeline import (
|
||||
@@ -37,11 +50,20 @@ from .tokenizer_processor import TokenizerProcessor
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessor",
|
||||
"AddTeleopActionAsComplimentaryData",
|
||||
"AddTeleopEventsAsInfo",
|
||||
"DeviceProcessor",
|
||||
"DoneProcessor",
|
||||
"MapDeltaActionToRobotAction",
|
||||
"EnvTransition",
|
||||
"GripperPenaltyProcessor",
|
||||
"IdentityProcessor",
|
||||
"ImageCropResizeProcessor",
|
||||
"InfoProcessor",
|
||||
"InterventionActionProcessor",
|
||||
"JointVelocityProcessor",
|
||||
"MapDeltaActionToRobotAction",
|
||||
"MotorCurrentProcessor",
|
||||
"NormalizerProcessor",
|
||||
"UnnormalizerProcessor",
|
||||
"hotswap_stats",
|
||||
@@ -49,10 +71,14 @@ __all__ = [
|
||||
"ProcessorStep",
|
||||
"ProcessorStepRegistry",
|
||||
"RenameProcessor",
|
||||
"RewardClassifierProcessor",
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"ToBatchProcessor",
|
||||
"TokenizerProcessor",
|
||||
"TimeLimitProcessor",
|
||||
"Numpy2TorchActionProcessor",
|
||||
"Torch2NumpyActionProcessor",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
||||
@dataclass
|
||||
class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
"""
|
||||
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
|
||||
for use with inverse kinematics processors.
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.delta_x": float,
|
||||
"action.delta_y": float,
|
||||
"action.delta_z": float,
|
||||
"action.gripper": float (optional),
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.target_x": float,
|
||||
"action.target_y": float,
|
||||
"action.target_z": float,
|
||||
"action.target_wx": float,
|
||||
"action.target_wy": float,
|
||||
"action.target_wz": float,
|
||||
"action.gripper": float,
|
||||
}
|
||||
"""
|
||||
|
||||
# Scale factors for delta movements
|
||||
position_scale: float = 1.0
|
||||
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
|
||||
gripper_deadzone: float = 0.1 # Threshold for gripper activation
|
||||
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def action(self, action: dict | Tensor | None) -> dict:
|
||||
if action is None:
|
||||
return {}
|
||||
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||
if isinstance(action, dict):
|
||||
delta_x = action.pop("action.delta_x", 0.0)
|
||||
delta_y = action.pop("action.delta_y", 0.0)
|
||||
delta_z = action.pop("action.delta_z", 0.0)
|
||||
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
|
||||
else:
|
||||
delta_x = action[0].item()
|
||||
delta_y = action[1].item()
|
||||
delta_z = action[2].item()
|
||||
gripper = action[3].item()
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
position_magnitude = abs(delta_x) + abs(delta_y) + abs(delta_z)
|
||||
enabled = position_magnitude > 1e-6 # Small threshold to avoid noise
|
||||
|
||||
# Scale the deltas appropriately
|
||||
scaled_delta_x = float(delta_x) * self.position_scale
|
||||
scaled_delta_y = float(delta_y) * self.position_scale
|
||||
scaled_delta_z = float(delta_z) * self.position_scale
|
||||
|
||||
# For gamepad/keyboard, we don't have rotation input, so set to 0
|
||||
# These could be extended in the future for more sophisticated teleoperators
|
||||
target_wx = 0.0
|
||||
target_wy = 0.0
|
||||
target_wz = 0.0
|
||||
|
||||
# Update action with robot target format
|
||||
action = {
|
||||
"action.enabled": enabled,
|
||||
"action.target_x": scaled_delta_x,
|
||||
"action.target_y": scaled_delta_y,
|
||||
"action.target_z": scaled_delta_z,
|
||||
"action.target_wx": target_wx,
|
||||
"action.target_wy": target_wy,
|
||||
"action.target_wz": target_wz,
|
||||
"action.gripper": float(gripper),
|
||||
}
|
||||
|
||||
self._prev_enabled = enabled
|
||||
return action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transform features to match output format."""
|
||||
# Update features to reflect the new action format
|
||||
features.update(
|
||||
{
|
||||
"action.enabled": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_x": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_y": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_z": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wx": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wy": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wz": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.gripper": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
}
|
||||
)
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
self._prev_enabled = False
|
||||
@@ -66,9 +66,26 @@ class DeviceProcessor:
|
||||
self._target_float_dtype = None
|
||||
|
||||
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Process a tensor by moving to device and optionally converting float dtype."""
|
||||
# Move to device first
|
||||
tensor = tensor.to(self.device, non_blocking=self.non_blocking)
|
||||
"""Process a tensor by moving to device and optionally converting float dtype.
|
||||
|
||||
If the tensor is already on a GPU and we're configured for a GPU, it preserves
|
||||
that GPU placement (useful for multi-GPU training with Accelerate).
|
||||
Otherwise, it moves to the configured device.
|
||||
"""
|
||||
# Determine target device
|
||||
if tensor.is_cuda and self._device.type == "cuda":
|
||||
# Both tensor and target are on GPU - preserve tensor's GPU placement
|
||||
# This handles multi-GPU scenarios where Accelerate has already placed
|
||||
# tensors on the correct GPU for each process
|
||||
target_device = tensor.device
|
||||
else:
|
||||
# Either tensor is on CPU, or we're configured for CPU
|
||||
# In both cases, use the configured device
|
||||
target_device = self._device
|
||||
|
||||
# Only move if necessary
|
||||
if tensor.device != target_device:
|
||||
tensor = tensor.to(target_device, non_blocking=self.non_blocking)
|
||||
|
||||
# Convert float dtype if specified and tensor is floating point
|
||||
if self._target_float_dtype is not None and tensor.is_floating_point():
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
GRIPPER_KEY = "gripper"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||
@dataclass
|
||||
class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
|
||||
"""Add teleoperator action to transition complementary data."""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def complementary_data(self, complementary_data: dict | None) -> dict:
|
||||
complementary_data = {} if complementary_data is None else dict(complementary_data)
|
||||
complementary_data["teleop_action"] = self.teleop_device.get_action()
|
||||
return complementary_data
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||
@dataclass
|
||||
class AddTeleopEventsAsInfo(InfoProcessor):
|
||||
"""Add teleoperator control events to transition info."""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def info(self, info: dict | None) -> dict:
|
||||
info = {} if info is None else dict(info)
|
||||
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
|
||||
info.update(teleop_events)
|
||||
return info
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("torch2numpy_action_processor")
|
||||
@dataclass
|
||||
class Torch2NumpyActionProcessor(ActionProcessor):
|
||||
"""Convert PyTorch tensor actions to NumPy arrays."""
|
||||
|
||||
squeeze_batch_dim: bool = True
|
||||
|
||||
def action(self, action: torch.Tensor | None) -> np.ndarray | None:
|
||||
if action is None:
|
||||
return None
|
||||
|
||||
if not isinstance(action, torch.Tensor):
|
||||
raise TypeError(
|
||||
f"Expected torch.Tensor or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
|
||||
numpy_action = action.detach().cpu().numpy()
|
||||
|
||||
# Remove batch dimensions but preserve action dimensions
|
||||
# Only squeeze if there's a batch dimension (first dim == 1)
|
||||
if (
|
||||
self.squeeze_batch_dim
|
||||
and numpy_action.shape
|
||||
and len(numpy_action.shape) > 1
|
||||
and numpy_action.shape[0] == 1
|
||||
):
|
||||
numpy_action = numpy_action.squeeze(0)
|
||||
|
||||
return numpy_action
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
||||
@dataclass
|
||||
class Numpy2TorchActionProcessor(ActionProcessor):
|
||||
"""Convert NumPy array action to PyTorch tensor."""
|
||||
|
||||
def action(self, action: np.ndarray | None) -> torch.Tensor | None:
|
||||
if action is None:
|
||||
return None
|
||||
if not isinstance(action, np.ndarray):
|
||||
raise TypeError(
|
||||
f"Expected np.ndarray or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
torch_action = torch.from_numpy(action)
|
||||
return torch_action
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||
@dataclass
|
||||
class ImageCropResizeProcessor(ObservationProcessor):
|
||||
"""Crop and resize image observations."""
|
||||
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
def observation(self, observation: dict | None) -> dict | None:
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
if self.resize_size is None and not self.crop_params_dict:
|
||||
return observation
|
||||
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Process all image keys in the observation
|
||||
for key in observation:
|
||||
if "image" not in key:
|
||||
continue
|
||||
|
||||
image = observation[key]
|
||||
device = image.device
|
||||
# NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu
|
||||
if device.type == "mps":
|
||||
image = image.cpu()
|
||||
# Crop if crop params are provided for this key
|
||||
if self.crop_params_dict is not None and key in self.crop_params_dict:
|
||||
crop_params = self.crop_params_dict[key]
|
||||
image = F.crop(image, *crop_params)
|
||||
if self.resize_size is not None:
|
||||
image = F.resize(image, self.resize_size)
|
||||
image = image.clamp(0.0, 1.0)
|
||||
new_observation[key] = image.to(device)
|
||||
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"crop_params_dict": self.crop_params_dict,
|
||||
"resize_size": self.resize_size,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if self.resize_size is None:
|
||||
return features
|
||||
for key in features:
|
||||
if "image" in key:
|
||||
features[key] = PolicyFeature(type=features[key].type, shape=self.resize_size)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("time_limit_processor")
|
||||
class TimeLimitProcessor:
|
||||
"""Track episode steps and enforce time limits."""
|
||||
|
||||
max_episode_steps: int
|
||||
current_step: int = 0
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is None:
|
||||
return transition
|
||||
|
||||
self.current_step += 1
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
truncated = True
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.TRUNCATED] = truncated
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"max_episode_steps": self.max_episode_steps,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
self.current_step = 0
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessor:
|
||||
"""Apply penalty for inappropriate gripper usage."""
|
||||
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Calculate gripper penalty and add to complementary data."""
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
|
||||
if complementary_data is None or action is None:
|
||||
return transition
|
||||
|
||||
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return transition
|
||||
|
||||
gripper_action = action[f"action.{GRIPPER_KEY}.pos"]
|
||||
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
||||
|
||||
# Normalize gripper state and action
|
||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||
|
||||
# Calculate penalty boolean as in original
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
||||
)
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Add penalty information to complementary data
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
# Create new complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data["discrete_penalty"] = gripper_penalty
|
||||
|
||||
# Create new transition with updated complementary data
|
||||
new_transition = transition.copy()
|
||||
existing_comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
existing_comp_data.update(new_complementary_data)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = existing_comp_data # type: ignore[misc]
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the processor state."""
|
||||
self.last_gripper_state = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||
class InterventionActionProcessor:
|
||||
"""Handle human intervention actions and episode termination."""
|
||||
|
||||
use_gripper: bool = False
|
||||
terminate_on_success: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return transition
|
||||
|
||||
# Get intervention signals from complementary data
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
teleop_action = info.get("teleop_action", {})
|
||||
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
|
||||
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
|
||||
success = info.get(TeleopEvents.SUCCESS, False)
|
||||
rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False)
|
||||
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Override action if intervention is active
|
||||
if is_intervention and teleop_action is not None:
|
||||
if isinstance(teleop_action, dict):
|
||||
# Convert teleop_action dict to tensor format
|
||||
action_list = [
|
||||
teleop_action.get("action.delta_x", 0.0),
|
||||
teleop_action.get("action.delta_y", 0.0),
|
||||
teleop_action.get("action.delta_z", 0.0),
|
||||
]
|
||||
if self.use_gripper:
|
||||
action_list.append(teleop_action.get("gripper", 1.0))
|
||||
elif isinstance(teleop_action, np.ndarray):
|
||||
action_list = teleop_action.tolist()
|
||||
else:
|
||||
action_list = teleop_action
|
||||
|
||||
teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
|
||||
new_transition[TransitionKey.ACTION] = teleop_action_tensor
|
||||
|
||||
# Handle episode termination
|
||||
new_transition[TransitionKey.DONE] = bool(terminate_episode) or (
|
||||
self.terminate_on_success and success
|
||||
)
|
||||
new_transition[TransitionKey.REWARD] = float(success)
|
||||
|
||||
# Update info with intervention metadata
|
||||
info = new_transition.get(TransitionKey.INFO, {})
|
||||
info[TeleopEvents.IS_INTERVENTION] = is_intervention
|
||||
info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode
|
||||
info[TeleopEvents.SUCCESS] = success
|
||||
new_transition[TransitionKey.INFO] = info
|
||||
|
||||
# Update complementary data with teleop action
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"use_gripper": self.use_gripper,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("reward_classifier_processor")
|
||||
class RewardClassifierProcessor:
|
||||
"""Apply reward classification to image observations."""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
device: str = "cpu"
|
||||
success_threshold: float = 0.5
|
||||
success_reward: float = 1.0
|
||||
terminate_on_success: bool = True
|
||||
|
||||
reward_classifier: Any = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the reward classifier after dataclass initialization."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
|
||||
self.reward_classifier.to(self.device)
|
||||
self.reward_classifier.eval()
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None or self.reward_classifier is None:
|
||||
return transition
|
||||
|
||||
# Extract images from observation
|
||||
images = {key: value for key, value in observation.items() if "image" in key}
|
||||
|
||||
if not images:
|
||||
return transition
|
||||
|
||||
# Run reward classifier
|
||||
start_time = time.perf_counter()
|
||||
with torch.inference_mode():
|
||||
success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold)
|
||||
|
||||
classifier_frequency = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
# Calculate reward and termination
|
||||
reward = transition.get(TransitionKey.REWARD, 0.0)
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
|
||||
if success == 1.0:
|
||||
reward = self.success_reward
|
||||
if self.terminate_on_success:
|
||||
terminated = True
|
||||
|
||||
# Update transition
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = reward
|
||||
new_transition[TransitionKey.DONE] = terminated
|
||||
|
||||
# Update info with classifier frequency
|
||||
info = new_transition.get(TransitionKey.INFO, {})
|
||||
info["reward_classifier_frequency"] = classifier_frequency
|
||||
new_transition[TransitionKey.INFO] = info
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"device": self.device,
|
||||
"success_threshold": self.success_threshold,
|
||||
"success_reward": self.success_reward,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -0,0 +1,116 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||
class JointVelocityProcessor:
|
||||
"""Add joint velocity information to observations."""
|
||||
|
||||
joint_velocity_limits: float = 100.0
|
||||
dt: float = 1.0 / 10
|
||||
num_dof: int | None = None
|
||||
|
||||
last_joint_positions: torch.Tensor | None = None
|
||||
|
||||
def observation(self, observation: dict | None) -> dict | None:
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
# Get current joint positions (assuming they're in observation.state)
|
||||
current_positions = observation.get("observation.state")
|
||||
if current_positions is None:
|
||||
return observation
|
||||
|
||||
# Initialize last joint positions if not already set
|
||||
if self.last_joint_positions is None:
|
||||
self.last_joint_positions = current_positions.clone()
|
||||
|
||||
# Compute velocities
|
||||
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
|
||||
self.last_joint_positions = current_positions.clone()
|
||||
|
||||
# Extend observation with velocities
|
||||
extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation["observation.state"] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"joint_velocity_limits": self.joint_velocity_limits,
|
||||
"dt": self.dt,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_joint_positions = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if "observation.state" in features and self.num_dof is not None:
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
|
||||
original_feature = features["observation.state"]
|
||||
# Double the shape to account for positions + velocities
|
||||
new_shape = (original_feature.shape[0] + self.num_dof,) + original_feature.shape[1:]
|
||||
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("current_processor")
|
||||
class MotorCurrentProcessor(ObservationProcessor):
|
||||
"""Add motor current information to observations."""
|
||||
|
||||
robot: Robot | None = None
|
||||
|
||||
def observation(self, observation: dict | None) -> dict | None:
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
# Get current values from robot state
|
||||
if self.robot is None:
|
||||
return observation
|
||||
present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
|
||||
motor_currents = torch.tensor(
|
||||
[present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
|
||||
dtype=torch.float32,
|
||||
).unsqueeze(0)
|
||||
|
||||
current_state = observation.get("observation.state")
|
||||
if current_state is None:
|
||||
return observation
|
||||
|
||||
extended_state = torch.cat([current_state, motor_currents], dim=-1)
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation["observation.state"] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if "observation.state" in features and self.robot is not None:
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
|
||||
original_feature = features["observation.state"]
|
||||
# Add motor current dimensions to the original state shape
|
||||
num_motors = 0
|
||||
if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
|
||||
num_motors = len(self.robot.bus.motors) # type: ignore[attr-defined]
|
||||
|
||||
if num_motors > 0:
|
||||
new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
|
||||
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
return features
|
||||
@@ -134,9 +134,19 @@ class TokenizerProcessor:
|
||||
if task is None:
|
||||
return transition
|
||||
|
||||
# Tokenize the task
|
||||
# Tokenize the task (creates CPU tensors)
|
||||
tokenized_prompt = self._tokenize_text(task)
|
||||
|
||||
# Detect device from existing tensors in the transition
|
||||
target_device = self._detect_device(transition)
|
||||
|
||||
# Move tokenized tensors to match the device of other data
|
||||
if target_device is not None:
|
||||
tokenized_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_prompt.items()
|
||||
}
|
||||
|
||||
# Get or create observation dict
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
@@ -153,6 +163,45 @@ class TokenizerProcessor:
|
||||
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
|
||||
return transition
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
"""Detect device from existing tensors in the transition.
|
||||
|
||||
This allows the tokenized tensors to match the device of other data,
|
||||
which is especially important for multi-GPU training with Accelerate.
|
||||
|
||||
Args:
|
||||
transition: The transition to search for existing tensors.
|
||||
|
||||
Returns:
|
||||
The device of the first tensor found, or None if no tensors exist.
|
||||
"""
|
||||
# Check observation tensors first (most likely to exist)
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation:
|
||||
for value in observation.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
|
||||
# Check action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if isinstance(action, torch.Tensor):
|
||||
return action.device
|
||||
|
||||
# Check other tensor fields
|
||||
for key in [TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED]:
|
||||
value = transition.get(key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
|
||||
# Check complementary data for tensors
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data:
|
||||
for value in complementary_data.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
|
||||
return None # No tensors found, keep on CPU
|
||||
|
||||
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
||||
"""Tokenize text using the configured tokenizer.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ Records a dataset. Actions for the robot can be either generated by teleoperatio
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \
|
||||
@@ -36,7 +36,7 @@ python -m lerobot.record \
|
||||
|
||||
Example recording with bimanual so100:
|
||||
```shell
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
|
||||
--robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
|
||||
|
||||
@@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot.
|
||||
Examples:
|
||||
|
||||
```shell
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
@@ -28,7 +28,7 @@ python -m lerobot.replay \
|
||||
|
||||
Example replay with bimanual so100:
|
||||
```shell
|
||||
python -m lerobot.replay \
|
||||
lerobot-replay \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
|
||||
--robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
|
||||
|
||||
@@ -53,6 +53,9 @@ class EEReferenceAndDelta:
|
||||
kinematics: RobotKinematics
|
||||
end_effector_step_sizes: dict
|
||||
motor_names: list[str]
|
||||
use_latched_reference: bool = (
|
||||
True # If True, latch reference on enable; if False, always use current pose
|
||||
)
|
||||
|
||||
reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
||||
@@ -69,7 +72,10 @@ class EEReferenceAndDelta:
|
||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||
)
|
||||
|
||||
q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
if "reference_joint_positions" in comp:
|
||||
q = comp["reference_joint_positions"]
|
||||
else:
|
||||
q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
|
||||
# Current pose from FK on measured joints
|
||||
t_curr = self.kinematics.forward_kinematics(q)
|
||||
@@ -85,11 +91,12 @@ class EEReferenceAndDelta:
|
||||
desired = None
|
||||
|
||||
if enabled:
|
||||
# Latch a reference at the rising edge; also be defensive if None
|
||||
if not self._prev_enabled or self.reference_ee_pose is None:
|
||||
self.reference_ee_pose = t_curr.copy()
|
||||
|
||||
ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr
|
||||
ref = t_curr
|
||||
if self.use_latched_reference:
|
||||
# Latched reference mode: latch reference at the rising edge
|
||||
if not self._prev_enabled or self.reference_ee_pose is None:
|
||||
self.reference_ee_pose = t_curr.copy()
|
||||
ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr
|
||||
|
||||
delta_p = np.array(
|
||||
[
|
||||
@@ -100,7 +107,6 @@ class EEReferenceAndDelta:
|
||||
dtype=float,
|
||||
)
|
||||
r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
|
||||
|
||||
desired = np.eye(4, dtype=float)
|
||||
desired[:3, :3] = ref[:3, :3] @ r_abs
|
||||
desired[:3, 3] = ref[:3, 3] + delta_p
|
||||
@@ -292,6 +298,8 @@ class InverseKinematicsEEToJoints:
|
||||
else:
|
||||
new_act[f"action.{name}.pos"] = float(q_target[i])
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
if not self.initial_guess_current_joints:
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
@@ -332,6 +340,7 @@ class GripperVelocityToJoint:
|
||||
speed_factor: float = 20.0
|
||||
clip_min: float = 0.0
|
||||
clip_max: float = 100.0
|
||||
discrete_gripper: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION) or {}
|
||||
@@ -347,6 +356,15 @@ class GripperVelocityToJoint:
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
return transition
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
# 0: open, 1: close, 2: stay
|
||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
gripper_action = act.get("action.gripper", 1.0)
|
||||
gripper_action = gripper_action - 1.0
|
||||
gripper_action *= self.clip_max
|
||||
act["action.gripper"] = gripper_action
|
||||
|
||||
# Get current gripper position from complementary data
|
||||
raw = comp.get("raw_joint_positions") or {}
|
||||
curr_pos = float(raw.get("gripper"))
|
||||
|
||||
@@ -141,10 +141,10 @@ python lerobot/scripts/control_robot.py \
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/aloha_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_aloha_test \
|
||||
|
||||
@@ -21,7 +21,7 @@ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/di
|
||||
for 10 episodes.
|
||||
|
||||
```
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/diffusion_pusht \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
@@ -32,7 +32,7 @@ python -m lerobot.scripts.eval \
|
||||
|
||||
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
|
||||
```
|
||||
python -m lerobot.scripts.eval \
|
||||
lerobot-eval \
|
||||
--policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
|
||||
@@ -62,9 +62,16 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.scripts.rl.gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
@@ -236,7 +243,8 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
||||
env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
@@ -257,6 +265,12 @@ def act_with_policy(
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
@@ -274,45 +288,61 @@ def act_with_policy(
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
observation = transition[TransitionKey.OBSERVATION]
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
# Use the new step function
|
||||
new_transition = step_env_and_process_transition(
|
||||
env=online_env,
|
||||
transition=transition,
|
||||
action=action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
|
||||
# Extract values from processed transition
|
||||
next_observation = new_transition[TransitionKey.OBSERVATION]
|
||||
executed_action = new_transition[TransitionKey.ACTION]
|
||||
reward = new_transition[TransitionKey.REWARD]
|
||||
done = new_transition.get(TransitionKey.DONE, False)
|
||||
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
complementary_info = {
|
||||
"discrete_penalty": torch.tensor(
|
||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||
),
|
||||
}
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
state=observation,
|
||||
action=executed_action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
next_state=next_observation,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
# Update transition for next iteration
|
||||
transition = new_transition
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
@@ -347,12 +377,20 @@ def act_with_policy(
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
# Reset intervention counters and environment
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -75,6 +75,7 @@ from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl import learner_service
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
MAX_MESSAGE_SIZE,
|
||||
@@ -1174,7 +1175,7 @@ def process_transitions(
|
||||
|
||||
# Add to offline buffer if it's an intervention
|
||||
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
"is_intervention"
|
||||
TeleopEvents.IS_INTERVENTION
|
||||
):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
@@ -152,6 +153,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json")
|
||||
postprocessor.from_pretrained(
|
||||
cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
@@ -209,10 +214,6 @@ def train(cfg: TrainPipelineConfig):
|
||||
batch = preprocessor(batch)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
@@ -244,7 +245,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor)
|
||||
save_checkpoint(
|
||||
checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
@@ -288,10 +291,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
policy.push_model_to_hub(cfg)
|
||||
if preprocessor:
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
if postprocessor:
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
from typing import Any, Callable
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.common.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
has_method,
|
||||
init_logging,
|
||||
is_launched_with_accelerate,
|
||||
)
|
||||
from lerobot.common.utils.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
policy: PreTrainedPolicy,
|
||||
batch: Any,
|
||||
optimizer: Optimizer,
|
||||
grad_clip_norm: float,
|
||||
grad_scaler: GradScaler,
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
accelerator: Callable = None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
policy.train()
|
||||
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
accelerator.backward(loss)
|
||||
accelerator.unscale_gradients(optimizer=optimizer)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
optimizer.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Step through pytorch scheduler at every batch instead of epoch
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
|
||||
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
|
||||
|
||||
train_metrics.loss = loss.item()
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: Callable):
|
||||
cfg.validate()
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
if accelerator.is_main_process:
|
||||
# Disable logging on non-main processes.
|
||||
cfg.wandb.enable = False
|
||||
|
||||
if cfg.wandb.enable and cfg.wandb.project:
|
||||
wandb_logger = WandBLogger(cfg)
|
||||
else:
|
||||
wandb_logger = None
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
|
||||
if cfg.seed is not None:
|
||||
set_seed(cfg.seed, accelerator=accelerator)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
policy.to(device)
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
if accelerator.is_main_process:
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.episode_data_index,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
|
||||
train_tracker = MetricsTracker(
|
||||
cfg.batch_size,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
train_metrics,
|
||||
initial_step=step,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and accelerator.is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps and accelerator.is_main_process
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 and accelerator.is_main_process
|
||||
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
cfg,
|
||||
accelerator.unwrap_model(policy),
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
env=eval_env,
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
"pc_success": AverageMeter("success", ":.1f"),
|
||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||
}
|
||||
eval_tracker = MetricsTracker(
|
||||
cfg.batch_size,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
eval_metrics,
|
||||
initial_step=step,
|
||||
accelerator=None,
|
||||
)
|
||||
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
||||
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
||||
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
|
||||
logging.info(eval_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
if not accelerator or accelerator.is_main_process:
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
# We set step_scheduler_with_optimizer False to prevent accelerate from
|
||||
# adjusting the lr_scheduler steps based on the num_processes
|
||||
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
|
||||
train(accelerator=accelerator)
|
||||
@@ -18,7 +18,7 @@ Helper to set motor ids and baudrate.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.setup_motors \
|
||||
lerobot-setup-motors \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem575E0031751
|
||||
```
|
||||
|
||||
@@ -18,7 +18,7 @@ Simple script to control a robot from teleoperation.
|
||||
Example:
|
||||
|
||||
```shell
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
@@ -32,7 +32,7 @@ python -m lerobot.teleoperate \
|
||||
Example teleoperation with bimanual so100:
|
||||
|
||||
```shell
|
||||
python -m lerobot.teleoperate \
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_so100_follower \
|
||||
--robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
|
||||
--robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
|
||||
|
||||
@@ -16,4 +16,4 @@
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
from .utils import make_teleoperator_from_config
|
||||
from .utils import TeleopEvents, make_teleoperator_from_config
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
import logging
|
||||
|
||||
from ..utils import TeleopEvents
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
@@ -134,10 +136,10 @@ class KeyboardController(InputController):
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -255,13 +257,13 @@ class GamepadController(InputController):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
@@ -451,11 +453,11 @@ class GamepadControllerHID(InputController):
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
from .configuration_gamepad import GamepadTeleopConfig
|
||||
|
||||
|
||||
@@ -93,9 +94,9 @@ class GamepadTeleop(Teleoperator):
|
||||
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
|
||||
|
||||
action_dict = {
|
||||
"delta_x": gamepad_action[0],
|
||||
"delta_y": gamepad_action[1],
|
||||
"delta_z": gamepad_action[2],
|
||||
"action.delta_x": gamepad_action[0],
|
||||
"action.delta_y": gamepad_action[1],
|
||||
"action.delta_z": gamepad_action[2],
|
||||
}
|
||||
|
||||
# Default gripper action is to stay
|
||||
@@ -107,6 +108,48 @@ class GamepadTeleop(Teleoperator):
|
||||
|
||||
return action_dict
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get extra control events from the gamepad such as intervention status,
|
||||
episode termination, success indicators, etc.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
"""
|
||||
if self.gamepad is None:
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: False,
|
||||
TeleopEvents.TERMINATE_EPISODE: False,
|
||||
TeleopEvents.SUCCESS: False,
|
||||
TeleopEvents.RERECORD_EPISODE: False,
|
||||
}
|
||||
|
||||
# Update gamepad state to get fresh inputs
|
||||
self.gamepad.update()
|
||||
|
||||
# Check if intervention is active
|
||||
is_intervention = self.gamepad.should_intervene()
|
||||
|
||||
# Get episode end status
|
||||
episode_end_status = self.gamepad.get_episode_end_status()
|
||||
terminate_episode = episode_end_status in [
|
||||
TeleopEvents.RERECORD_EPISODE,
|
||||
TeleopEvents.FAILURE,
|
||||
]
|
||||
success = episode_end_status == TeleopEvents.SUCCESS
|
||||
rerecord_episode = episode_end_status == TeleopEvents.RERECORD_EPISODE
|
||||
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: is_intervention,
|
||||
TeleopEvents.TERMINATE_EPISODE: terminate_episode,
|
||||
TeleopEvents.SUCCESS: success,
|
||||
TeleopEvents.RERECORD_EPISODE: rerecord_episode,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the gamepad."""
|
||||
if self.gamepad is not None:
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig
|
||||
|
||||
PYNPUT_AVAILABLE = True
|
||||
@@ -167,13 +168,13 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2, "action.gripper": 3},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (3,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
|
||||
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2},
|
||||
}
|
||||
|
||||
def _on_press(self, key):
|
||||
@@ -226,12 +227,75 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
self.current_pressed.clear()
|
||||
|
||||
action_dict = {
|
||||
"delta_x": delta_x,
|
||||
"delta_y": delta_y,
|
||||
"delta_z": delta_z,
|
||||
"action.delta_x": delta_x,
|
||||
"action.delta_y": delta_y,
|
||||
"action.delta_z": delta_z,
|
||||
}
|
||||
|
||||
if self.config.use_gripper:
|
||||
action_dict["gripper"] = gripper_action
|
||||
|
||||
return action_dict
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get extra control events from the keyboard such as intervention status,
|
||||
episode termination, success indicators, etc.
|
||||
|
||||
Keyboard mappings:
|
||||
- Any movement keys pressed = intervention active
|
||||
- 's' key = success (terminate episode successfully)
|
||||
- 'r' key = rerecord episode (terminate and rerecord)
|
||||
- 'q' key = quit episode (terminate without success)
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: False,
|
||||
TeleopEvents.TERMINATE_EPISODE: False,
|
||||
TeleopEvents.SUCCESS: False,
|
||||
TeleopEvents.RERECORD_EPISODE: False,
|
||||
}
|
||||
|
||||
# Check if any movement keys are currently pressed (indicates intervention)
|
||||
movement_keys = [
|
||||
keyboard.Key.up,
|
||||
keyboard.Key.down,
|
||||
keyboard.Key.left,
|
||||
keyboard.Key.right,
|
||||
keyboard.Key.shift,
|
||||
keyboard.Key.shift_r,
|
||||
keyboard.Key.ctrl_r,
|
||||
keyboard.Key.ctrl_l,
|
||||
]
|
||||
is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys)
|
||||
|
||||
# Check for episode control commands from misc_keys_queue
|
||||
terminate_episode = False
|
||||
success = False
|
||||
rerecord_episode = False
|
||||
|
||||
# Process any pending misc keys
|
||||
while not self.misc_keys_queue.empty():
|
||||
key = self.misc_keys_queue.get_nowait()
|
||||
if key == "s":
|
||||
success = True
|
||||
elif key == "r":
|
||||
terminate_episode = True
|
||||
rerecord_episode = True
|
||||
elif key == "q":
|
||||
terminate_episode = True
|
||||
success = False
|
||||
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: is_intervention,
|
||||
TeleopEvents.TERMINATE_EPISODE: terminate_episode,
|
||||
TeleopEvents.SUCCESS: success,
|
||||
TeleopEvents.RERECORD_EPISODE: rerecord_episode,
|
||||
}
|
||||
|
||||
@@ -12,10 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
|
||||
|
||||
class TeleopEvents(Enum):
|
||||
"""Shared constants for teleoperator events across teleoperators."""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
RERECORD_EPISODE = "rerecord_episode"
|
||||
IS_INTERVENTION = "is_intervention"
|
||||
TERMINATE_EPISODE = "terminate_episode"
|
||||
|
||||
|
||||
def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
if config.type == "keyboard":
|
||||
from .keyboard import KeyboardTeleop
|
||||
|
||||
@@ -44,7 +44,7 @@ Below is the short version on how to train and run inference/eval:
|
||||
### Train from scratch
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/<dataset> \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/<desired_policy_repo_id> \
|
||||
@@ -59,7 +59,7 @@ _Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
|
||||
### Evaluate the policy/run inference
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--dataset.repo_id=<hf_user>/eval_<dataset> \
|
||||
--policy.path=<hf_user>/<desired_policy_repo_id> \
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from lerobot.utils.utils import format_big_number
|
||||
|
||||
@@ -84,6 +84,7 @@ class MetricsTracker:
|
||||
"samples",
|
||||
"episodes",
|
||||
"epochs",
|
||||
"accelerator",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -93,12 +94,14 @@ class MetricsTracker:
|
||||
num_episodes: int,
|
||||
metrics: dict[str, AverageMeter],
|
||||
initial_step: int = 0,
|
||||
accelerator: Callable | None = None,
|
||||
):
|
||||
self.__dict__.update(dict.fromkeys(self.__keys__))
|
||||
self._batch_size = batch_size
|
||||
self._num_frames = num_frames
|
||||
self._avg_samples_per_ep = num_frames / num_episodes
|
||||
self.metrics = metrics
|
||||
self.accelerator = accelerator
|
||||
|
||||
self.steps = initial_step
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
@@ -128,7 +131,7 @@ class MetricsTracker:
|
||||
Updates metrics that depend on 'step' for one step.
|
||||
"""
|
||||
self.steps += 1
|
||||
self.samples += self._batch_size
|
||||
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import random
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Callable, Generator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -164,7 +164,7 @@ def set_rng_state(random_state_dict: dict[str, Any]):
|
||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_seed(seed) -> None:
|
||||
def set_seed(seed: int, accelerator: Callable | None = None) -> None:
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
@@ -172,6 +172,11 @@ def set_seed(seed) -> None:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
if accelerator:
|
||||
from accelerate.utils import set_seed as accelerate_set_seed
|
||||
|
||||
accelerate_set_seed(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
|
||||
@@ -17,10 +17,9 @@ import time
|
||||
|
||||
|
||||
def busy_wait(seconds):
|
||||
if platform.system() == "Darwin":
|
||||
# On Mac, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||
if platform.system() == "Darwin" or platform.system() == "Windows":
|
||||
# On Mac and Windows, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||
# but it consumes CPU cycles.
|
||||
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
||||
end_time = time.perf_counter() + seconds
|
||||
while time.perf_counter() < end_time:
|
||||
pass
|
||||
|
||||
@@ -32,6 +32,7 @@ from lerobot.datasets.utils import load_json, write_json
|
||||
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
||||
|
||||
|
||||
@@ -74,7 +75,8 @@ def save_checkpoint(
|
||||
policy: PreTrainedPolicy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor=None,
|
||||
preprocessor: RobotProcessor | None = None,
|
||||
postprocessor: RobotProcessor | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -105,6 +107,8 @@ def save_checkpoint(
|
||||
cfg.save_pretrained(pretrained_dir)
|
||||
if preprocessor is not None:
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import time
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
@@ -56,13 +57,15 @@ def auto_select_torch_device() -> torch.device:
|
||||
|
||||
|
||||
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
def get_safe_torch_device(
|
||||
try_device: str, log: bool = False, accelerator: Callable | None = None
|
||||
) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
try_device = str(try_device)
|
||||
match try_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
device = accelerator.device if accelerator else torch.device("cuda")
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
@@ -116,6 +119,7 @@ def init_logging(
|
||||
display_pid: bool = False,
|
||||
console_level: str = "INFO",
|
||||
file_level: str = "DEBUG",
|
||||
accelerator: Callable | None = None,
|
||||
):
|
||||
def custom_format(record: logging.LogRecord) -> str:
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -152,6 +156,11 @@ def init_logging(
|
||||
file_handler.setLevel(file_level.upper())
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
if accelerator is not None and not accelerator.is_main_process:
|
||||
# Disable duplicate logging on non-main processes
|
||||
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
@@ -165,6 +174,10 @@ def format_big_number(num, precision=0):
|
||||
return num
|
||||
|
||||
|
||||
def is_launched_with_accelerate() -> bool:
|
||||
return "ACCELERATE_MIXED_PRECISION" in os.environ
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
|
||||
oid sha256:ee0c29d3782aa1cadcf4dc6ed767d9460ff00fff9fc70b460502340b832eefcc
|
||||
size 5104
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
|
||||
size 33400
|
||||
oid sha256:ea76e6711959fd3f905ec2bdc306f488920f00ec99421e4870d05f6205eb323e
|
||||
size 31672
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
|
||||
oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c
|
||||
size 515400
|
||||
|
||||
+2
-2
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
|
||||
size 33400
|
||||
oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
|
||||
oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
|
||||
size 992
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
|
||||
oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
|
||||
size 47424
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
|
||||
size 49120
|
||||
oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
|
||||
size 47424
|
||||
|
||||
@@ -23,7 +23,8 @@ from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.policies.factory import make_policy, make_policy_config, make_processor
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
|
||||
@@ -37,7 +38,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
train_cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
dataset = make_dataset(train_cfg)
|
||||
dataset_stats = dataset.meta.stats
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_processor(train_cfg.policy, dataset_stats=dataset_stats)
|
||||
policy.train()
|
||||
|
||||
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
|
||||
@@ -49,7 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
batch = preprocessor(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
if output_dict is not None:
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict["loss"] = loss
|
||||
@@ -96,7 +101,12 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
else:
|
||||
actions_queue = train_cfg.policy.n_action_repeats
|
||||
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
actions = {}
|
||||
for i in range(actions_queue):
|
||||
unnormalized_action = policy.select_action(obs).contiguous()
|
||||
action_robot = postprocessor({TransitionKey.ACTION: unnormalized_action}).get(TransitionKey.ACTION)
|
||||
actions[str(i)] = action_robot
|
||||
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
|
||||
oid sha256:d640988f2269cf6aa03c8ee17f9d096edace83d837f90025011fafec5bf53c61
|
||||
size 200
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
||||
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
|
||||
size 16904
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
||||
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
|
||||
size 164
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
||||
size 36312
|
||||
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
|
||||
size 35496
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
|
||||
oid sha256:2212ae7b910d14d723214f5af50985e419f7bd0f4261565ef48b1ef495443d6d
|
||||
size 200
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
||||
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
|
||||
size 16904
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
||||
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
|
||||
size 164
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
||||
size 36312
|
||||
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
|
||||
size 35496
|
||||
|
||||
@@ -26,7 +26,7 @@ from safetensors.torch import load_file
|
||||
from lerobot import available_policies
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
@@ -41,7 +41,6 @@ from lerobot.policies.factory import (
|
||||
make_policy_config,
|
||||
make_processor,
|
||||
)
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
@@ -266,108 +265,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
def test_normalize(insert_temporal_dim):
|
||||
"""
|
||||
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
|
||||
an exception when the forward pass is called without the stats having been provided.
|
||||
|
||||
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
|
||||
expected.
|
||||
"""
|
||||
|
||||
input_features = {
|
||||
"observation.image": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 96, 96),
|
||||
),
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(10,),
|
||||
),
|
||||
}
|
||||
output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
"min": torch.randn(3, 1, 1),
|
||||
"max": torch.randn(3, 1, 1),
|
||||
},
|
||||
"observation.state": {
|
||||
"mean": torch.randn(10),
|
||||
"std": torch.randn(10),
|
||||
"min": torch.randn(10),
|
||||
"max": torch.randn(10),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.randn(5),
|
||||
"std": torch.randn(5),
|
||||
"min": torch.randn(5),
|
||||
"max": torch.randn(5),
|
||||
},
|
||||
}
|
||||
|
||||
bsize = 2
|
||||
input_batch = {
|
||||
"observation.image": torch.randn(bsize, 3, 96, 96),
|
||||
"observation.state": torch.randn(bsize, 10),
|
||||
}
|
||||
output_batch = {
|
||||
"action": torch.randn(bsize, 5),
|
||||
}
|
||||
|
||||
if insert_temporal_dim:
|
||||
tdim = 4
|
||||
|
||||
for key in input_batch:
|
||||
# [2,3,96,96] -> [2,tdim,3,96,96]
|
||||
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
|
||||
|
||||
for key in output_batch:
|
||||
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
||||
|
||||
# test without stats
|
||||
normalize = Normalize(input_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
|
||||
normalize(input_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_normalize = Normalize(input_features, norm_map, stats=None)
|
||||
new_normalize.load_state_dict(normalize.state_dict())
|
||||
new_normalize(input_batch)
|
||||
|
||||
# test without stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
||||
unnormalize(output_batch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multikey", [True, False])
|
||||
def test_multikey_construction(multikey: bool):
|
||||
"""
|
||||
@@ -467,6 +364,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
NOTE: If the test don't pass and you don't change the policy, and note the dependencies version,
|
||||
and you changed your processor, you might have to update the test artifact.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for ACT policy processor."""
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.processor_act import make_act_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default ACT configuration for testing."""
|
||||
config = ACTConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
config.device = "cpu"
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
ACTION: {"mean": torch.zeros(4), "std": torch.ones(4)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_act_processor_basic():
|
||||
"""Test basic creation of ACT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
|
||||
|
||||
def test_act_processor_normalization():
|
||||
"""Test that ACT processor correctly normalizes and unnormalizes data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is normalized and batched
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized
|
||||
assert postprocessed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_cuda():
|
||||
"""Test ACT processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_accelerate_scenario():
|
||||
"""Test ACT processor in simulated Accelerate scenario (data already on GPU)."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU (not moved unnecessarily)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_act_processor_multi_gpu():
|
||||
"""Test ACT processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU (like in multi-GPU training)
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)}
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1 (not moved to cuda:0)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_act_processor_without_stats():
|
||||
"""Test ACT processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors, but normalization won't have stats
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work (but won't normalize without stats)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_act_processor_save_and_load():
|
||||
"""Test saving and loading ACT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
def test_act_processor_device_placement_preservation():
|
||||
"""Test that ACT processor preserves device placement correctly."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
# Test with CPU config
|
||||
config.device = "cpu"
|
||||
preprocessor, _ = make_act_processor(config, stats)
|
||||
|
||||
# Process CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_mixed_precision():
|
||||
"""Test ACT processor with mixed precision (float16)."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Modify the device processor to use float16
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.float16
|
||||
|
||||
|
||||
def test_act_processor_batch_consistency():
|
||||
"""Test that ACT processor handles different batch sizes correctly."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_processor(config, stats)
|
||||
|
||||
# Test single sample (unbatched)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
|
||||
|
||||
# Test already batched data
|
||||
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
|
||||
action_batched = torch.randn(8, 4)
|
||||
transition_batched = create_transition(observation_batched, action_batched)
|
||||
|
||||
processed_batched = preprocessor(transition_batched)
|
||||
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
|
||||
assert processed_batched[TransitionKey.ACTION].shape[0] == 8
|
||||
@@ -0,0 +1,329 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for Reward Classifier processor."""
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
||||
from lerobot.processor import DeviceProcessor, IdentityProcessor, NormalizerProcessor, RobotProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default Reward Classifier configuration for testing."""
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), # Classifier output
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.IDENTITY, # No normalization for classifier output
|
||||
}
|
||||
config.device = "cpu"
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
|
||||
OBS_IMAGE: {}, # No normalization for images
|
||||
"reward": {}, # No normalization for classifier output
|
||||
}
|
||||
|
||||
|
||||
def test_make_classifier_processor_basic():
|
||||
"""Test basic creation of Classifier processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "classifier_preprocessor"
|
||||
assert postprocessor.name == "classifier_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 3
|
||||
assert isinstance(preprocessor.steps[0], NormalizerProcessor) # For input features
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor) # For output features
|
||||
assert isinstance(preprocessor.steps[2], DeviceProcessor)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], IdentityProcessor)
|
||||
|
||||
|
||||
def test_classifier_processor_normalization():
|
||||
"""Test that Classifier processor correctly normalizes data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1) # Dummy action/reward
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is processed
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (10,)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (1,)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_classifier_processor_cuda():
|
||||
"""Test Classifier processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
reward_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(reward_transition)
|
||||
|
||||
# Check that output is back on CPU
|
||||
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_classifier_processor_accelerate_scenario():
|
||||
"""Test Classifier processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10).to(device),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_classifier_processor_multi_gpu():
|
||||
"""Test Classifier processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10).to(device),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_classifier_processor_without_stats():
|
||||
"""Test Classifier processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_classifier_processor_save_and_load():
|
||||
"""Test saving and loading Classifier processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (10,)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (1,)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_classifier_processor_mixed_precision():
|
||||
"""Test Classifier processor with mixed precision."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10, dtype=torch.float32),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(1, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.float16
|
||||
|
||||
|
||||
def test_classifier_processor_batch_data():
|
||||
"""Test Classifier processor with batched data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Test with batched data
|
||||
batch_size = 16
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(batch_size, 10),
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 1)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (batch_size, 10)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (batch_size, 3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (batch_size, 1)
|
||||
|
||||
|
||||
def test_classifier_processor_postprocessor_identity():
|
||||
"""Test that Classifier postprocessor uses IdentityProcessor correctly."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Create test data for postprocessor
|
||||
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions
|
||||
transition = create_transition(action=reward)
|
||||
|
||||
# Process through postprocessor
|
||||
processed = postprocessor(transition)
|
||||
|
||||
# IdentityProcessor should leave values unchanged (except device)
|
||||
assert torch.allclose(processed[TransitionKey.ACTION].cpu(), reward.cpu())
|
||||
assert processed[TransitionKey.ACTION].device.type == "cpu"
|
||||
@@ -820,6 +820,143 @@ def test_complementary_data_none():
|
||||
assert TransitionKey.COMPLEMENTARY_DATA not in result
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_preserves_gpu_placement():
|
||||
"""Test that DeviceProcessor preserves GPU placement when tensor is already on GPU."""
|
||||
processor = DeviceProcessor(device="cuda:0")
|
||||
|
||||
# Create tensors already on GPU
|
||||
observation = {
|
||||
"observation.state": torch.randn(10).cuda(), # Already on GPU
|
||||
"observation.image": torch.randn(3, 224, 224).cuda(), # Already on GPU
|
||||
}
|
||||
action = torch.randn(5).cuda() # Already on GPU
|
||||
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tensors remain on their original GPU
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Verify no unnecessary copies were made (same data pointer)
|
||||
assert torch.equal(
|
||||
result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_multi_gpu_preservation():
|
||||
"""Test that DeviceProcessor preserves placement on different GPUs in multi-GPU setup."""
|
||||
# Test 1: GPU-to-GPU preservation (cuda:0 config, cuda:1 input)
|
||||
processor_gpu = DeviceProcessor(device="cuda:0")
|
||||
|
||||
# Create tensors on cuda:1 (simulating Accelerate placement)
|
||||
cuda1_device = torch.device("cuda:1")
|
||||
observation = {
|
||||
"observation.state": torch.randn(10).to(cuda1_device),
|
||||
"observation.image": torch.randn(3, 224, 224).to(cuda1_device),
|
||||
}
|
||||
action = torch.randn(5).to(cuda1_device)
|
||||
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
result = processor_gpu(transition)
|
||||
|
||||
# Check that tensors remain on cuda:1 (not moved to cuda:0)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device
|
||||
assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device
|
||||
assert result[TransitionKey.ACTION].device == cuda1_device
|
||||
|
||||
# Test 2: GPU-to-CPU should move to CPU (not preserve GPU)
|
||||
processor_cpu = DeviceProcessor(device="cpu")
|
||||
|
||||
transition_gpu = create_transition(
|
||||
observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda()
|
||||
)
|
||||
result_cpu = processor_cpu(transition_gpu)
|
||||
|
||||
# Check that tensors are moved to CPU
|
||||
assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
|
||||
assert result_cpu[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_multi_gpu_with_cpu_tensors():
|
||||
"""Test that CPU tensors are moved to configured device even in multi-GPU context."""
|
||||
# Processor configured for cuda:1
|
||||
processor = DeviceProcessor(device="cuda:1")
|
||||
|
||||
# Mix of CPU and GPU tensors
|
||||
observation = {
|
||||
"observation.cpu": torch.randn(10), # CPU tensor
|
||||
"observation.gpu0": torch.randn(10).cuda(0), # Already on cuda:0
|
||||
"observation.gpu1": torch.randn(10).cuda(1), # Already on cuda:1
|
||||
}
|
||||
action = torch.randn(5) # CPU tensor
|
||||
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
result = processor(transition)
|
||||
|
||||
# CPU tensor should move to configured device (cuda:1)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.index == 1
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.index == 1
|
||||
|
||||
# GPU tensors should stay on their original devices
|
||||
assert result[TransitionKey.OBSERVATION]["observation.gpu0"].device.index == 0
|
||||
assert result[TransitionKey.OBSERVATION]["observation.gpu1"].device.index == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_multi_gpu_with_float_dtype():
|
||||
"""Test float dtype conversion works correctly with multi-GPU preservation."""
|
||||
processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
|
||||
|
||||
# Create float tensors on different GPUs
|
||||
observation = {
|
||||
"observation.gpu0": torch.randn(5, dtype=torch.float32).cuda(0),
|
||||
"observation.gpu1": torch.randn(5, dtype=torch.float32).cuda(1),
|
||||
"observation.cpu": torch.randn(5, dtype=torch.float32), # CPU
|
||||
}
|
||||
|
||||
transition = create_transition(observation=observation)
|
||||
result = processor(transition)
|
||||
|
||||
# Check device placement
|
||||
assert result[TransitionKey.OBSERVATION]["observation.gpu0"].device.index == 0
|
||||
assert result[TransitionKey.OBSERVATION]["observation.gpu1"].device.index == 1
|
||||
assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.index == 0 # Moved to cuda:0
|
||||
|
||||
# Check dtype conversion happened for all
|
||||
assert result[TransitionKey.OBSERVATION]["observation.gpu0"].dtype == torch.float16
|
||||
assert result[TransitionKey.OBSERVATION]["observation.gpu1"].dtype == torch.float16
|
||||
assert result[TransitionKey.OBSERVATION]["observation.cpu"].dtype == torch.float16
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_simulated_accelerate_scenario():
|
||||
"""Test a scenario simulating how Accelerate would use the processor."""
|
||||
# Simulate different processes getting different GPU assignments
|
||||
for gpu_id in range(min(torch.cuda.device_count(), 2)):
|
||||
# Each "process" has a processor configured for cuda:0
|
||||
# but data comes in already placed on the process's GPU
|
||||
processor = DeviceProcessor(device="cuda:0")
|
||||
|
||||
# Simulate data already placed by Accelerate
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
observation = {"observation.state": torch.randn(1, 10).to(device)}
|
||||
action = torch.randn(1, 5).to(device)
|
||||
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
result = processor(transition)
|
||||
|
||||
# Verify data stays on the GPU where Accelerate placed it
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device == device
|
||||
assert result[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_policy_processor_integration():
|
||||
"""Test integration with policy processors - input on GPU, output on CPU."""
|
||||
|
||||
@@ -0,0 +1,342 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for Diffusion policy processor."""
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default Diffusion configuration for testing."""
|
||||
config = DiffusionConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
config.device = "cpu"
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
OBS_IMAGE: {}, # No normalization for images
|
||||
ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_diffusion_processor_basic():
|
||||
"""Test basic creation of Diffusion processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
|
||||
|
||||
def test_diffusion_processor_with_images():
|
||||
"""Test Diffusion processor with image observations."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Create test data with images
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(7),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is batched
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 6)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_diffusion_processor_cuda():
|
||||
"""Test Diffusion processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(7),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_diffusion_processor_accelerate_scenario():
|
||||
"""Test Diffusion processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 7).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_diffusion_processor_multi_gpu():
|
||||
"""Test Diffusion processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 7).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_diffusion_processor_without_stats():
|
||||
"""Test Diffusion processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(7),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_diffusion_processor_save_and_load():
|
||||
"""Test saving and loading Diffusion processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(7),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 6)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_diffusion_processor_mixed_precision():
|
||||
"""Test Diffusion processor with mixed precision."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(7, dtype=torch.float32),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.float16
|
||||
|
||||
|
||||
def test_diffusion_processor_identity_normalization():
|
||||
"""Test that images with IDENTITY normalization are not normalized."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
image_value = torch.rand(3, 224, 224) * 255 # Large values
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(7),
|
||||
OBS_IMAGE: image_value.clone(),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Image should not be normalized (IDENTITY mode)
|
||||
# Just batched
|
||||
assert torch.allclose(processed[TransitionKey.OBSERVATION][OBS_IMAGE][0], image_value, rtol=1e-5)
|
||||
|
||||
|
||||
def test_diffusion_processor_batch_consistency():
|
||||
"""Test Diffusion processor with different batch sizes."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_diffusion_processor(config, stats)
|
||||
|
||||
# Test with different batch sizes
|
||||
for batch_size in [1, 8, 32]:
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(batch_size, 7) if batch_size > 1 else torch.randn(7),
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224) if batch_size > 1 else torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 6) if batch_size > 1 else torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check correct batch size
|
||||
expected_batch = batch_size if batch_size > 1 else 1
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == expected_batch
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape[0] == expected_batch
|
||||
assert processed[TransitionKey.ACTION].shape[0] == expected_batch
|
||||
@@ -25,6 +25,7 @@ from lerobot.processor.normalize_processor import (
|
||||
UnnormalizerProcessor,
|
||||
_convert_stats_to_tensors,
|
||||
hotswap_stats,
|
||||
rename_stats,
|
||||
)
|
||||
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
|
||||
|
||||
@@ -1604,3 +1605,156 @@ def test_hotswap_stats_functional_test():
|
||||
new_result["observation"]["observation.image"], observation["observation.image"]
|
||||
)
|
||||
assert not torch.allclose(new_result["action"], action)
|
||||
|
||||
|
||||
def test_zero_std_uses_eps():
|
||||
"""When std == 0, (x-mean)/(std+eps) is well-defined; x==mean should map to 0."""
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": np.array([0.5]), "std": np.array([0.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
|
||||
|
||||
observation = {"observation.state": torch.tensor([0.5])} # equals mean
|
||||
out = normalizer(create_transition(observation=observation))
|
||||
assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([0.0]))
|
||||
|
||||
|
||||
def test_min_equals_max_maps_to_minus_one():
|
||||
"""When min == max, MIN_MAX path maps to -1 after [-1,1] scaling for x==min."""
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX}
|
||||
stats = {"observation.state": {"min": np.array([2.0]), "max": np.array([2.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
|
||||
|
||||
observation = {"observation.state": torch.tensor([2.0])}
|
||||
out = normalizer(create_transition(observation=observation))
|
||||
assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0]))
|
||||
|
||||
|
||||
def test_action_normalized_despite_normalize_keys():
|
||||
"""Action normalization is independent of normalize_keys filter for observations."""
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (1,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.state"}
|
||||
)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0])
|
||||
)
|
||||
out = normalizer(transition)
|
||||
# (3-1)/2 = 1.0 ; (3-(-1))/4 = 1.0
|
||||
assert torch.allclose(out[TransitionKey.ACTION], torch.tensor([1.0, 1.0]))
|
||||
|
||||
|
||||
def test_unnormalize_observations_mean_std_and_min_max():
|
||||
features = {
|
||||
"observation.ms": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"observation.mm": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
# Build two processors: one mean/std and one min/max
|
||||
unnorm_ms = UnnormalizerProcessor(
|
||||
features={"observation.ms": features["observation.ms"]},
|
||||
norm_map={FeatureType.STATE: NormalizationMode.MEAN_STD},
|
||||
stats={"observation.ms": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}},
|
||||
)
|
||||
unnorm_mm = UnnormalizerProcessor(
|
||||
features={"observation.mm": features["observation.mm"]},
|
||||
norm_map={FeatureType.STATE: NormalizationMode.MIN_MAX},
|
||||
stats={"observation.mm": {"min": np.array([0.0, -2.0]), "max": np.array([2.0, 2.0])}},
|
||||
)
|
||||
|
||||
tr = create_transition(
|
||||
observation={
|
||||
"observation.ms": torch.tensor([0.0, 0.0]), # → mean
|
||||
"observation.mm": torch.tensor([0.0, 0.0]), # → mid-point
|
||||
}
|
||||
)
|
||||
out_ms = unnorm_ms(tr)[TransitionKey.OBSERVATION]["observation.ms"]
|
||||
out_mm = unnorm_mm(tr)[TransitionKey.OBSERVATION]["observation.mm"]
|
||||
assert torch.allclose(out_ms, torch.tensor([1.0, -1.0]))
|
||||
assert torch.allclose(out_mm, torch.tensor([1.0, 0.0])) # mid of [0,2] and [-2,2]
|
||||
|
||||
|
||||
def test_rename_stats_basic():
|
||||
orig = {
|
||||
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
|
||||
"action": {"mean": np.array([0.0])},
|
||||
}
|
||||
mapping = {"observation.state": "observation.robot_state"}
|
||||
renamed = rename_stats(orig, mapping)
|
||||
assert "observation.robot_state" in renamed and "observation.state" not in renamed
|
||||
# Ensure deep copy: mutate original and verify renamed unaffected
|
||||
orig["observation.state"]["mean"][0] = 42.0
|
||||
assert renamed["observation.robot_state"]["mean"][0] != 42.0
|
||||
|
||||
|
||||
def test_unknown_observation_keys_ignored():
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
obs = {"observation.state": torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])}
|
||||
tr = create_transition(observation=obs)
|
||||
out = normalizer(tr)
|
||||
|
||||
# Unknown key should pass through unchanged and not be tracked
|
||||
assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.unknown"], obs["observation.unknown"])
|
||||
comp = out.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
assert "normalized_keys" in comp and "observation.unknown" not in comp["normalized_keys"]
|
||||
|
||||
|
||||
def test_batched_action_normalization():
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
actions = torch.tensor([[1.0, -1.0], [3.0, 3.0]]) # first equals mean → zeros; second → [1, 1]
|
||||
out = normalizer(create_transition(action=actions))[TransitionKey.ACTION]
|
||||
expected = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
|
||||
assert torch.allclose(out, expected)
|
||||
|
||||
|
||||
def test_complementary_data_preservation():
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
comp = {"existing": 123}
|
||||
tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp)
|
||||
out = normalizer(tr)
|
||||
new_comp = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert new_comp["existing"] == 123 and "normalized_keys" in new_comp
|
||||
|
||||
|
||||
def test_roundtrip_normalize_unnormalize_non_identity():
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX}
|
||||
stats = {
|
||||
"observation.state": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])},
|
||||
"action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])},
|
||||
}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Add a time dimension in action for broadcasting check (B,T,D)
|
||||
obs = {"observation.state": torch.tensor([[3.0, 3.0], [1.0, -1.0]])}
|
||||
act = torch.tensor([[[0.0, -1.0], [1.0, 1.0]]]) # shape (1,2,2) already in [-1,1]
|
||||
|
||||
tr = create_transition(observation=obs, action=act)
|
||||
out = unnormalizer(normalizer(tr))
|
||||
|
||||
assert torch.allclose(
|
||||
out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5
|
||||
)
|
||||
assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5)
|
||||
|
||||
@@ -0,0 +1,336 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for PI0 policy processor."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
elif key == "complementary_data":
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default PI0 configuration for testing."""
|
||||
config = PI0Config()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
config.device = "cpu"
|
||||
config.tokenizer_max_length = 128
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
|
||||
OBS_IMAGE: {}, # No normalization for images
|
||||
ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_pi0_processor_basic():
|
||||
"""Test basic creation of PI0 processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
|
||||
preprocessor, postprocessor = make_pi0_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 6
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], Pi0NewLineProcessor)
|
||||
# Step 4 would be TokenizerProcessor but it's mocked
|
||||
assert isinstance(preprocessor.steps[5], DeviceProcessor)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
|
||||
|
||||
def test_pi0_newline_processor_single_task():
|
||||
"""Test Pi0NewLineProcessor with single task string."""
|
||||
processor = Pi0NewLineProcessor()
|
||||
|
||||
# Test with task that doesn't have newline
|
||||
transition = create_transition(complementary_data={"task": "test task"})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
|
||||
|
||||
# Test with task that already has newline
|
||||
transition = create_transition(complementary_data={"task": "test task\n"})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
|
||||
|
||||
|
||||
def test_pi0_newline_processor_list_of_tasks():
|
||||
"""Test Pi0NewLineProcessor with list of task strings."""
|
||||
processor = Pi0NewLineProcessor()
|
||||
|
||||
# Test with list of tasks
|
||||
tasks = ["task1", "task2\n", "task3"]
|
||||
transition = create_transition(complementary_data={"task": tasks})
|
||||
result = processor(transition)
|
||||
expected = ["task1\n", "task2\n", "task3\n"]
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected
|
||||
|
||||
|
||||
def test_pi0_newline_processor_empty_transition():
|
||||
"""Test Pi0NewLineProcessor with empty transition."""
|
||||
processor = Pi0NewLineProcessor()
|
||||
|
||||
# Test with no complementary_data
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
assert result == transition
|
||||
|
||||
# Test with complementary_data but no task
|
||||
transition = create_transition(complementary_data={"other": "data"})
|
||||
result = processor(transition)
|
||||
assert result == transition
|
||||
|
||||
# Test with None task
|
||||
transition = create_transition(complementary_data={"task": None})
|
||||
result = processor(transition)
|
||||
assert result == transition
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_pi0_processor_cuda():
|
||||
"""Test PI0 processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_pi0_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action, complementary_data={"task": "test task"})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_pi0_processor_accelerate_scenario():
|
||||
"""Test PI0 processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_pi0_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU and batched
|
||||
device = torch.device("cuda:0")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 10).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_pi0_processor_multi_gpu():
|
||||
"""Test PI0 processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_pi0_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 10).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_pi0_processor_without_stats():
|
||||
"""Test PI0 processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
# Mock the tokenizer processor
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
|
||||
preprocessor, postprocessor = make_pi0_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
|
||||
def test_pi0_newline_processor_state_dict():
|
||||
"""Test Pi0NewLineProcessor state dict methods."""
|
||||
processor = Pi0NewLineProcessor()
|
||||
|
||||
# Test state_dict (should be empty)
|
||||
state = processor.state_dict()
|
||||
assert state == {}
|
||||
|
||||
# Test load_state_dict (should do nothing)
|
||||
processor.load_state_dict({})
|
||||
|
||||
# Test reset (should do nothing)
|
||||
processor.reset()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
assert config == {}
|
||||
@@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for SAC policy processor."""
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.processor_sac import make_sac_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default SAC configuration for testing."""
|
||||
config = SACConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
config.device = "cpu"
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
|
||||
ACTION: {"min": torch.full((5,), -1.0), "max": torch.ones(5)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_sac_processor_basic():
|
||||
"""Test basic creation of SAC processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
|
||||
|
||||
def test_sac_processor_normalization_modes():
|
||||
"""Test that SAC processor correctly handles different normalization modes."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
|
||||
action = torch.rand(5) * 2 - 1 # Range [-1, 1]
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is normalized and batched
|
||||
# State should be mean-std normalized
|
||||
# Action should be min-max normalized to [-1, 1]
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized (but still batched)
|
||||
assert postprocessed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_cuda():
|
||||
"""Test SAC processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_accelerate_scenario():
|
||||
"""Test SAC processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_sac_processor_multi_gpu():
|
||||
"""Test SAC processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_sac_processor_without_stats():
|
||||
"""Test SAC processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_sac_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_sac_processor_save_and_load():
|
||||
"""Test saving and loading SAC processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_mixed_precision():
|
||||
"""Test SAC processor with mixed precision."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
|
||||
action = torch.randn(5, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.float16
|
||||
|
||||
|
||||
def test_sac_processor_batch_data():
|
||||
"""Test SAC processor with batched data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Test with batched data
|
||||
batch_size = 32
|
||||
observation = {OBS_STATE: torch.randn(batch_size, 10)}
|
||||
action = torch.randn(batch_size, 5)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (batch_size, 10)
|
||||
assert processed[TransitionKey.ACTION].shape == (batch_size, 5)
|
||||
|
||||
|
||||
def test_sac_processor_edge_cases():
|
||||
"""Test SAC processor with edge cases."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_processor(config, stats)
|
||||
|
||||
# Test with empty observation
|
||||
transition = create_transition(observation={}, action=torch.randn(5))
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION] == {}
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
# Test with None action
|
||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None)
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||
# When action is None, it may still be present with None value
|
||||
assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None
|
||||
@@ -0,0 +1,350 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for SmolVLA policy processor."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.processor_smolvla import SmolVLANewLineProcessor, make_smolvla_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
elif key == "complementary_data":
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default SmolVLA configuration for testing."""
|
||||
config = SmolVLAConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
config.device = "cpu"
|
||||
config.vlm_model_name = "HuggingFaceTB/SmolVLM-Instruct"
|
||||
config.pad_language_to = "max_length"
|
||||
config.tokenizer_max_length = 100
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(8), "std": torch.ones(8)},
|
||||
OBS_IMAGE: {}, # No normalization for images
|
||||
ACTION: {"min": torch.full((7,), -1.0), "max": torch.ones(7)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_smolvla_processor_basic():
|
||||
"""Test basic creation of SmolVLA processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 6
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], SmolVLANewLineProcessor)
|
||||
# Step 4 would be TokenizerProcessor but it's mocked
|
||||
assert isinstance(preprocessor.steps[5], DeviceProcessor)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
|
||||
|
||||
def test_smolvla_newline_processor_single_task():
|
||||
"""Test SmolVLANewLineProcessor with single task string."""
|
||||
processor = SmolVLANewLineProcessor()
|
||||
|
||||
# Test with task that doesn't have newline
|
||||
transition = create_transition(complementary_data={"task": "test task"})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
|
||||
|
||||
# Test with task that already has newline
|
||||
transition = create_transition(complementary_data={"task": "test task\n"})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
|
||||
|
||||
|
||||
def test_smolvla_newline_processor_list_of_tasks():
|
||||
"""Test SmolVLANewLineProcessor with list of task strings."""
|
||||
processor = SmolVLANewLineProcessor()
|
||||
|
||||
# Test with list of tasks
|
||||
tasks = ["task1", "task2\n", "task3"]
|
||||
transition = create_transition(complementary_data={"task": tasks})
|
||||
result = processor(transition)
|
||||
expected = ["task1\n", "task2\n", "task3\n"]
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected
|
||||
|
||||
|
||||
def test_smolvla_newline_processor_empty_transition():
|
||||
"""Test SmolVLANewLineProcessor with empty transition."""
|
||||
processor = SmolVLANewLineProcessor()
|
||||
|
||||
# Test with no complementary_data
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
assert result == transition
|
||||
|
||||
# Test with complementary_data but no task
|
||||
transition = create_transition(complementary_data={"other": "data"})
|
||||
result = processor(transition)
|
||||
assert result == transition
|
||||
|
||||
# Test with None task
|
||||
transition = create_transition(complementary_data={"task": None})
|
||||
result = processor(transition)
|
||||
assert result == transition
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_smolvla_processor_cuda():
|
||||
"""Test SmolVLA processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(8),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action, complementary_data={"task": "test task"})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_smolvla_processor_accelerate_scenario():
|
||||
"""Test SmolVLA processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU and batched
|
||||
device = torch.device("cuda:0")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 8).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_smolvla_processor_multi_gpu():
|
||||
"""Test SmolVLA processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 8).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_smolvla_processor_without_stats():
|
||||
"""Test SmolVLA processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
# Mock the tokenizer processor
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
|
||||
preprocessor, postprocessor = make_smolvla_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
|
||||
def test_smolvla_newline_processor_state_dict():
|
||||
"""Test SmolVLANewLineProcessor state dict methods."""
|
||||
processor = SmolVLANewLineProcessor()
|
||||
|
||||
# Test state_dict (should be empty)
|
||||
state = processor.state_dict()
|
||||
assert state == {}
|
||||
|
||||
# Test load_state_dict (should do nothing)
|
||||
processor.load_state_dict({})
|
||||
|
||||
# Test reset (should do nothing)
|
||||
processor.reset()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
assert config == {}
|
||||
|
||||
|
||||
def test_smolvla_newline_processor_transform_features():
|
||||
"""Test SmolVLANewLineProcessor transform_features method."""
|
||||
processor = SmolVLANewLineProcessor()
|
||||
|
||||
# Test transform_features
|
||||
features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
}
|
||||
result = processor.transform_features(features)
|
||||
assert result == features # Should return unchanged
|
||||
@@ -0,0 +1,350 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for TDMPC policy processor."""
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default TDMPC configuration for testing."""
|
||||
config = TDMPCConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(12,)),
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
config.device = "cpu"
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(12), "std": torch.ones(12)},
|
||||
OBS_IMAGE: {}, # No normalization for images
|
||||
ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_tdmpc_processor_basic():
|
||||
"""Test basic creation of TDMPC processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
|
||||
|
||||
def test_tdmpc_processor_normalization():
|
||||
"""Test that TDMPC processor correctly normalizes and unnormalizes data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(12),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is processed and batched
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 12)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 6)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized (but still batched)
|
||||
assert postprocessed[TransitionKey.ACTION].shape == (1, 6)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_tdmpc_processor_cuda():
|
||||
"""Test TDMPC processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(12),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_tdmpc_processor_accelerate_scenario():
|
||||
"""Test TDMPC processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(12).to(device),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(6).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_tdmpc_processor_multi_gpu():
|
||||
"""Test TDMPC processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(12).to(device),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(6).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_tdmpc_processor_without_stats():
|
||||
"""Test TDMPC processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(12),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_tdmpc_processor_save_and_load():
|
||||
"""Test saving and loading TDMPC processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(12),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 12)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 6)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_tdmpc_processor_mixed_precision():
|
||||
"""Test TDMPC processor with mixed precision."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(12, dtype=torch.float32),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.float16
|
||||
|
||||
|
||||
def test_tdmpc_processor_batch_data():
|
||||
"""Test TDMPC processor with batched data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Test with batched data
|
||||
batch_size = 64
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(batch_size, 12),
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (batch_size, 12)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (batch_size, 3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (batch_size, 6)
|
||||
|
||||
|
||||
def test_tdmpc_processor_edge_cases():
|
||||
"""Test TDMPC processor with edge cases."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_tdmpc_processor(config, stats)
|
||||
|
||||
# Test with only state observation (no image)
|
||||
observation = {OBS_STATE: torch.randn(12)}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 12)
|
||||
assert OBS_IMAGE not in processed[TransitionKey.OBSERVATION]
|
||||
|
||||
# Test with only image observation (no state)
|
||||
observation = {OBS_IMAGE: torch.randn(3, 224, 224)}
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert OBS_STATE not in processed[TransitionKey.OBSERVATION]
|
||||
@@ -725,3 +725,264 @@ def test_custom_padding_side(mock_auto_tokenizer):
|
||||
processor_right(transition)
|
||||
|
||||
assert tracking_tokenizer.padding_side_calls[-1] == "right"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_device_detection_cpu():
|
||||
"""Test that tokenized tensors stay on CPU when other tensors are on CPU."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CPU tensors
|
||||
observation = {"observation.state": torch.randn(10)} # CPU tensor
|
||||
action = torch.randn(5) # CPU tensor
|
||||
transition = create_transition(
|
||||
observation=observation, action=action, complementary_data={"task": "test task"}
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tokenized tensors are on CPU
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
assert tokens.device.type == "cpu"
|
||||
assert attention_mask.device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@require_package("transformers")
|
||||
def test_device_detection_cuda():
|
||||
"""Test that tokenized tensors are moved to CUDA when other tensors are on CUDA."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CUDA tensors
|
||||
observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor
|
||||
action = torch.randn(5).cuda() # CUDA tensor
|
||||
transition = create_transition(
|
||||
observation=observation, action=action, complementary_data={"task": "test task"}
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tokenized tensors are on CUDA
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
assert tokens.device.type == "cuda"
|
||||
assert attention_mask.device.type == "cuda"
|
||||
assert tokens.device.index == 0 # Should be on same device as input
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
@require_package("transformers")
|
||||
def test_device_detection_multi_gpu():
|
||||
"""Test that tokenized tensors match device in multi-GPU setup."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Test with tensors on cuda:1
|
||||
device = torch.device("cuda:1")
|
||||
observation = {"observation.state": torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(
|
||||
observation=observation, action=action, complementary_data={"task": "multi gpu test"}
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tokenized tensors are on cuda:1
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
assert tokens.device == device
|
||||
assert attention_mask.device == device
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_device_detection_no_tensors():
|
||||
"""Test that tokenized tensors stay on CPU when no other tensors exist."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with no tensors
|
||||
transition = create_transition(
|
||||
observation={"metadata": {"key": "value"}}, # No tensors
|
||||
complementary_data={"task": "no tensor test"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tokenized tensors are on CPU (default)
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
assert tokens.device.type == "cpu"
|
||||
assert attention_mask.device.type == "cpu"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_device_detection_mixed_devices():
|
||||
"""Test device detection when tensors are on different devices (uses first found)."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Create transition with mixed devices
|
||||
observation = {
|
||||
"observation.cpu": torch.randn(10), # CPU
|
||||
"observation.cuda": torch.randn(10).cuda(), # CUDA
|
||||
}
|
||||
transition = create_transition(
|
||||
observation=observation, complementary_data={"task": "mixed device test"}
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# The device detection should use the first tensor found
|
||||
# (iteration order depends on dict, but result should be consistent)
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
# Both should be on the same device
|
||||
assert tokens.device == attention_mask.device
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@require_package("transformers")
|
||||
def test_device_detection_from_action():
|
||||
"""Test that device is detected from action tensor when no observation tensors exist."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with action on CUDA but no observation tensors
|
||||
observation = {"metadata": {"key": "value"}} # No tensors in observation
|
||||
action = torch.randn(5).cuda()
|
||||
transition = create_transition(
|
||||
observation=observation, action=action, complementary_data={"task": "action device test"}
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tokenized tensors match action's device
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
assert tokens.device.type == "cuda"
|
||||
assert attention_mask.device.type == "cuda"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@require_package("transformers")
|
||||
def test_device_detection_from_complementary_data():
|
||||
"""Test that device is detected from tensors in complementary_data."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with tensor in complementary_data
|
||||
transition = create_transition(
|
||||
observation={"metadata": {"key": "value"}}, # No tensors
|
||||
complementary_data={
|
||||
"task": "comp data test",
|
||||
"index": torch.tensor([42]).cuda(), # Tensor in complementary_data
|
||||
},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tokenized tensors match complementary_data tensor's device
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
assert tokens.device.type == "cuda"
|
||||
assert attention_mask.device.type == "cuda"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_device_detection_preserves_dtype():
|
||||
"""Test that device detection doesn't affect dtype of tokenized tensors."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with float tensor (to test dtype isn't affected)
|
||||
observation = {"observation.state": torch.randn(10, dtype=torch.float16)}
|
||||
transition = create_transition(observation=observation, complementary_data={"task": "dtype test"})
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tokenized tensors have correct dtypes (not affected by input dtype)
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
assert tokens.dtype == torch.long # Should remain long
|
||||
assert attention_mask.dtype == torch.bool # Should be bool (converted in processor)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@require_package("transformers")
|
||||
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
|
||||
def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
"""Test that TokenizerProcessor works correctly with DeviceProcessor in pipeline."""
|
||||
from lerobot.processor import DeviceProcessor
|
||||
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
# Create pipeline with TokenizerProcessor then DeviceProcessor
|
||||
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
|
||||
device_processor = DeviceProcessor(device="cuda:0")
|
||||
robot_processor = RobotProcessor([tokenizer_processor, device_processor])
|
||||
|
||||
# Start with CPU tensors
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(10)}, # CPU
|
||||
action=torch.randn(5), # CPU
|
||||
complementary_data={"task": "pipeline test"},
|
||||
)
|
||||
|
||||
result = robot_processor(transition)
|
||||
|
||||
# All tensors should end up on CUDA (moved by DeviceProcessor)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Tokenized tensors should also be on CUDA
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
assert tokens.device.type == "cuda"
|
||||
assert attention_mask.device.type == "cuda"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@require_package("transformers")
|
||||
def test_simulated_accelerate_scenario():
|
||||
"""Test scenario simulating Accelerate with data already on GPU."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Simulate Accelerate scenario: batch already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {
|
||||
"observation.state": torch.randn(1, 10).to(device), # Batched, on GPU
|
||||
"observation.image": torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU
|
||||
}
|
||||
action = torch.randn(1, 5).to(device) # Batched, on GPU
|
||||
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
complementary_data={"task": ["accelerate test"]}, # List for batched task
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Tokenized tensors should match GPU placement
|
||||
tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
assert tokens.device == device
|
||||
assert attention_mask.device == device
|
||||
# MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length)
|
||||
assert tokens.shape == (10,) # MockTokenizer behavior for single string in list
|
||||
assert attention_mask.shape == (10,)
|
||||
|
||||
@@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for VQBeT policy processor."""
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default VQBeT configuration for testing."""
|
||||
config = VQBeTConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
config.device = "cpu"
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(8), "std": torch.ones(8)},
|
||||
OBS_IMAGE: {}, # No normalization for images
|
||||
ACTION: {"min": torch.full((7,), -1.0), "max": torch.ones(7)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_vqbet_processor_basic():
|
||||
"""Test basic creation of VQBeT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
|
||||
|
||||
def test_vqbet_processor_with_images():
|
||||
"""Test VQBeT processor with image and state observations."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Create test data with images and states
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(8),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is batched
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 7)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_vqbet_processor_cuda():
|
||||
"""Test VQBeT processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(8),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_vqbet_processor_accelerate_scenario():
|
||||
"""Test VQBeT processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Simulate Accelerate: data already on GPU and batched
|
||||
device = torch.device("cuda:0")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 8).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_vqbet_processor_multi_gpu():
|
||||
"""Test VQBeT processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 8).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_vqbet_processor_without_stats():
|
||||
"""Test VQBeT processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(8),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_vqbet_processor_save_and_load():
|
||||
"""Test saving and loading VQBeT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(8),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 7)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_vqbet_processor_mixed_precision():
|
||||
"""Test VQBeT processor with mixed precision."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(8, dtype=torch.float32),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.float16
|
||||
|
||||
|
||||
def test_vqbet_processor_large_batch():
|
||||
"""Test VQBeT processor with large batch sizes."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Test with large batch
|
||||
batch_size = 128
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(batch_size, 8),
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 7)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (batch_size, 8)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (batch_size, 3, 224, 224)
|
||||
assert processed[TransitionKey.ACTION].shape == (batch_size, 7)
|
||||
|
||||
|
||||
def test_vqbet_processor_sequential_processing():
|
||||
"""Test VQBeT processor with sequential data processing."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_vqbet_processor(config, stats)
|
||||
|
||||
# Process multiple samples sequentially
|
||||
results = []
|
||||
for _ in range(5):
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(8),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
results.append(processed)
|
||||
|
||||
# Check that all results are consistent
|
||||
for result in results:
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
|
||||
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert result[TransitionKey.ACTION].shape == (1, 7)
|
||||
Reference in New Issue
Block a user