Compare commits

..

39 Commits

Author SHA1 Message Date
AdilZouitine 15960f0b5e refactor(utils): enhance task handling in add_envs_task function
- Improved the `add_envs_task` function to validate the output of `task_description` and `task` calls, ensuring they return lists of strings.
- Removed the use of `else` statement for environments without language instructions, simplifying the logic and enhancing readability.
- Streamlined the observation dictionary handling by ensuring consistent data types for task attributes.
2025-09-10 10:05:43 +02:00
AdilZouitine 8b43339563 debug 2025-09-10 10:05:43 +02:00
AdilZouitine 5dababd21e refactor(eval): remove redundant observation device conversion in rollout function
- Eliminated unnecessary device conversion for the observation dictionary within the `rollout` function, streamlining the code and enhancing readability.
- This change simplifies the observation handling process, aligning with the preference for clearer solutions.
2025-09-10 10:05:43 +02:00
AdilZouitine cbc46467b3 refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions
- Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline.
- Adjusted the implementation to apply preprocessing and postprocessing steps during policy evaluation, improving the overall data handling and processing flow.
2025-09-10 10:05:43 +02:00
Steven Palma e881fb6678 refactor(pipeline): feature contract now categorizes between OBS or Action (#1867)
* refactor(processor): signature of transform_features

* refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly

* refactor(processor): rename now is only for visual

* refactor(processor): update normalize processor

* refactor(processor): update vanilla processor features

* refactor(processor): feature contract now uses its own enum

* chore(processor): rename renameprocessor

* chore(processor): minor changes

* refactor(processor): add create & change aggregate

* refactor(processor): update aggregate

* refactor(processor): simplify to functions, fix features contracts and rename function

* test(processor): remove to converter tests as now they are very simple

* chore(docs): recover docs joint observations processor

* fix(processor): update RKP

* fix(tests): recv diff test_pipeline

* chore(tests): add docs to test

* chore(processor): leave obs language constant untouched

* fix(processor): correct new shape of feature in crop image processor
2025-09-09 18:27:30 +02:00
Adil Zouitine acf0ba7fb3 refactor(converters): rename _from_tensor to from_tensor_to_numpy for clarity (#1902)
- Updated the function name from _from_tensor to from_tensor_to_numpy to better reflect its purpose of converting PyTorch tensors to numpy arrays or scalars.
- Adjusted all references to the renamed function throughout the codebase to maintain consistency.
- Enhanced the _NormalizationMixin class to reconstruct the stats dictionary from tensor stats using the new function, ensuring compatibility after loading state dicts.
- Added tests to verify the correct reconstruction of stats and functionality of methods dependent on self.stats after loading.
2025-09-09 17:51:47 +02:00
Adil Zouitine a74b90edd1 refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions (#1900)
* refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions

- Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline.
- Adjusted the implementation to apply preprocessing and postprocessing steps during policy evaluation, improving the overall data handling and processing flow.

* refactor(eval): remove redundant observation device conversion in rollout function

- Eliminated unnecessary device conversion for the observation dictionary within the `rollout` function, streamlining the code and enhancing readability.
- This change simplifies the observation handling process, aligning with the preference for clearer solutions.

* debug

* refactor(utils): enhance task handling in add_envs_task function

- Improved the `add_envs_task` function to validate the output of `task_description` and `task` calls, ensuring they return lists of strings.
- Removed the use of `else` statement for environments without language instructions, simplifying the logic and enhancing readability.
- Streamlined the observation dictionary handling by ensuring consistent data types for task attributes.
2025-09-09 17:00:34 +02:00
Steven Palma 846677f9cc Merge branch 'main' into user/azouitine/2025-7-4-convert-codebase-with-pipeline 2025-09-08 22:35:13 +02:00
Steven Palma af9ddcf9a2 chore(docs): update doctrines pipeline files (#1872)
* docs(processor): update docstrings batch_processor

* docs(processor): update docstrings device_processor

* docs(processor): update docstrings tokenizer_processor

* update docstrings processor_act

* update docstrings for pipeline_features

* update docstrings for utils

* update docstring for processor_diffusion

* update docstrings factory

* add docstrings to pi0 processor

* add docstring to pi0fast processor

* add docstring classifier processor

* add docstring to sac processor

* add docstring smolvla processor

* add docstring to tdmpc processor

* add docstring to vqbet processor

* add docstrings to converters

* add docstrings for delta_action_processor

* add docstring to gym action processor

* update hil processor

* add docstring to joint obs processor

* add docstring to migrate_normalize_processor

* update docstrings normalize processor

* update docstring normalize processor

* update docstrings observation processor

* update docstrings rename_processor

* add docstrings robot_kinematic_processor

* cleanup rl comments

* add docstring to train.py

* add docstring to teleoperate.py

* add docstrings to phone_processor.py

* add docstrings to teleop_phone.py

* add docstrings to control_utils.py

* add docstrings to visualization_utils.py

---------

Co-authored-by: Pepijn <pepijn@huggingface.co>
2025-09-08 18:44:15 +02:00
Steven Palma d602e8169c fix(scripts): revert deletion of rs cam config import introduced by #1767 (#1876) 2025-09-08 18:29:39 +02:00
Steven Gong 49baccdccb Disable torque before applying calibration logic (#1889) 2025-09-08 11:38:13 +02:00
Adil Zouitine d32006440c refactor(processors): Improve Normalization Processor Performance and Device/Dtype Adaptability (#1880)
* refactor(processors): reorder processor steps for consistency across implementations

- Updated the order of processor steps in multiple files to ensure consistency, placing AddBatchDimensionProcessorStep and DeviceProcessorStep before NormalizerProcessorStep.
- Adjusted related test assertions to reflect the new order of steps in the preprocessor, enhancing clarity and maintainability.

* refactor(normalization): remove dtype specification in tensor conversion for adaptation logic

- Updated tensor conversion in the _NormalizationMixin class to remove explicit dtype specification, allowing for automatic adaptation of tensor types.
- Adjusted related tests to ensure proper functionality with the new tensor conversion logic, verifying that normalizers adapt correctly to input types.
2025-09-08 10:46:35 +02:00
Steven Palma f1cfdfced9 fix(processor): recover type inference for use of processors (#1873) 2025-09-05 11:31:30 +02:00
Gaëlle Lannuzel 6a3d57031a 2 add reachy 2 to updated lerobot (#1767)
* Start adding Reachy 2 (no camera)

* Fix joint shape

* Remove print

* Modify observation_features

* Fix observation state

* Try adding a fake Reachy teleoperator

* Saving test scripts

* Add reachy2camera to cameras

* Add teleop_left camera to observation

* Create test_reachy2_camera.py

* Update utils.py

* Add all rgb cameras

* Future depth work

* Try adding mobile_base velocity

* Update tests

* Update data_acquisition_server.py

* Update with use_external_commands

* Replay

* Usable with or without mobile base

* No need for new isntance

* Use same ip for cameras

* Remove useless imports

* Add resume

* Divide joints in multiple dicts

* Divide joinits into several dicts in teleoperator

* Fix forgotten method call

* Create test_robot_client.py

* Open gripper on start

* Add arguments for cameras

* Modify get_frame() requested size

* Call generate_joints_dict on _init_

* black + isort

* Add reachy2 in imports

* Add reachy2 dependencies

* Add documentation

* Update reachy2.mdx

* Update reachy2.mdx

* Clean files and add types

* Fix type in send_action

* Remove print

* Delete test files

* Clean code

* Update cameras

* Disconnect from camera

* Run pre-commit hooks

* Update pyproject.toml

* Create test_reachy2.py

* Fix generate_joints

* Update test_reachy2.py

* Update send_action test

* Update reachy2_cameras depth + CameraManager

* Update reachy2_camera tests

* Remove useless import and args

* Rename reachy2_teleoperator

* Create test_reachy2_teleoperator.py

* Fix remainging fake_teleoperator

* Remove useless elements

* Mock cameras in test_reachy2

* Delete commented lines

* Add use_present_position to teleoperator

* Add cameras tests

* Add check no part + test

* Use disable_torque_on_disconnect

* Use odometry for vel with present_position

* Update documentation

* Fix vel value type

* Use ensure_safe_goal_position

* Import joints dict from classes

* Update reachy2.mdx

* Update reachy2.mdx

* Update minimal version

* Update minimal version

* fix(tests) fixes for reachy2 tests; removing reachy2 references from the script

* Add reachy2_sdk fake as plugins

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-09-05 11:03:14 +02:00
Justin Huang d74494d92b Allow max_relative_target to be a float (#1837)
* Remove unused max_relative_target for stretch3

* Fix type annotation and allow integer max_relative_target values

* Configure max_relative_target to be floats instead of ints

* Update docs and types to reflect that max_relative_target can be a dict

* Remove unnecessary isinstance check for ints

* Fix typo in name

---------

Co-authored-by: Justin Huang <justin.huang@jpl.nasa.gov>
2025-09-05 09:58:47 +02:00
Adil Zouitine 888a5b6249 refactor(utils): simplify log_rerun_data function (#1864)
* refactor(logging): enhance log_rerun_data to handle observation and action separately

- Updated the `log_rerun_data` function to accept and log observation and action data more clearly, improving readability and maintainability.
- Refactored the `record_loop` and `teleop_loop` functions to extract and pass observation and action data to `log_rerun_data`, ensuring consistent logging format.

* refactor(tests): update test_log_rerun_data to align with log_rerun_data changes

- Modified test cases in `test_visualization_utils.py` to extract and pass observation and action data separately to `log_rerun_data`, improving clarity and consistency with recent function updates.
- Ensured that the tests reflect the new structure of `log_rerun_data` for better maintainability.

* refactor(processors): simplify calls to log_rerun + replace lambda functions with identity_transition

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-09-04 19:25:51 +02:00
Adil Zouitine f247aa0701 refactor(tests): update processor test assertions to reflect new preprocessor and postprocessor names (#1869)
- Changed assertions in multiple processor test files to verify the updated names from "robot_preprocessor" and "robot_postprocessor" to "policy_preprocessor" and "policy_postprocessor" for consistency with recent refactoring.
2025-09-04 17:34:06 +02:00
Adil Zouitine 1ac6a6d3fe refactor(constants): rename preprocessor and postprocessor constants for clarity (#1868)
- Updated constant names from PREPROCESSOR_DEFAULT_NAME and POSTPROCESSOR_DEFAULT_NAME to POLICY_PREPROCESSOR_DEFAULT_NAME and POLICY_POSTPROCESSOR_DEFAULT_NAME for better context.
- Adjusted references across multiple files to use the new constant names, ensuring consistency in the codebase.
2025-09-04 17:01:53 +02:00
Steven Palma e698c709d8 fix(deps): use in-house rotation utils over scipy throughout the codebase 2025-09-04 16:44:18 +02:00
Adil Zouitine a988da4789 feat(teleoperation): introduce HasTeleopEvents protocol and enhance teleop event handling (#1866)
- Added the HasTeleopEvents protocol to define a standard for teleoperators that provide control events.
- Implemented a runtime check to ensure teleoperators implement the get_teleop_events() method.
- Updated AddTeleopEventsAsInfoStep to utilize the new protocol, enhancing compatibility with custom teleoperators.
- Improved documentation for clarity on teleoperation event extraction and compatibility with built-in teleoperators.
2025-09-04 16:28:49 +02:00
Adil Zouitine 99963b6968 refactor(dependencies): remove scipy dependency and introduce custom rotation utilities (#1863)
- Removed the scipy dependency from the project to streamline requirements.
- Added a new `rotation.py` module containing a custom `Rotation` class that replicates essential functionalities of `scipy.spatial.transform.Rotation`, allowing for rotation vector, matrix, and quaternion conversions without external dependencies.
- Updated the `robot_kinematic_processor.py` to utilize the new custom rotation utilities.
2025-09-04 16:26:28 +02:00
Adil Zouitine 332ca4ccc5 refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862)
- Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity.
- Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests.
- Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability.
2025-09-04 16:22:03 +02:00
Adil Zouitine fc43246942 feat(record): add transition features to dataset and handle scalar vs array formatting in converters (#1861)
- Introduced new transition features (`next.reward`, `next.done`, `next.truncated`) in the dataset during recording.
- Updated the `transition_to_dataset_frame` function to handle scalar values correctly, ensuring compatibility with expected array formats for reward, done, and truncated features.
2025-09-04 16:17:31 +02:00
Adil Zouitine 793ad86fc9 refactor(processor): enforce config_filename requirement for HF Hub loading (#1860)
- Updated the DataProcessorPipeline to require a specific config_filename when loading from Hugging Face Hub, enhancing clarity and preventing errors.
- Simplified local path checks and improved error handling for invalid paths.
- Adjusted tests to reflect the new requirement and ensure proper error handling for various loading scenarios.
2025-09-04 10:31:18 +02:00
Adil Zouitine a6dbb65917 chore(processor): add type alias RobotProcessorPipeline and PolicyProcessorPipeline (#1859)
* feat(processor): introduce PolicyProcessorPipeline and RobotProcessorPipeline as type aliases for DataProcessorPipeline

- Added PolicyProcessorPipeline and RobotProcessorPipeline type aliases to enhance clarity and maintainability in the processor module.
- Updated the __all__ list to include the new pipelines for better module export consistency.

* refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline across multiple modules

- Updated all instances of DataProcessorPipeline to PolicyProcessorPipeline in various processor files for consistency and clarity.
- Adjusted function signatures to reflect the new pipeline type, enhancing maintainability and readability.

* refactor(processor): update hotswap_stats function to use PolicyProcessorPipeline

- Changed the parameter name from robot_processor to policy_processor for clarity.
- Ensured consistency with recent updates to the processor module by reflecting the new pipeline type in the function signature.

* refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline in migrate_policy_normalization.py

- Updated the preprocessor and postprocessor to use PolicyProcessorPipeline for consistency with recent changes in the processor module.
- Enhanced clarity and maintainability by aligning with the new pipeline structure.

* refactor(processor): update hotswap_stats to use PolicyProcessorPipeline

- Changed the parameter type in hotswap_stats from DataProcessorPipeline to PolicyProcessorPipeline for consistency with recent updates.
- Enhanced clarity by updating the function documentation to reflect the new pipeline type.

* refactor(processor): replace DataProcessorPipeline with RobotProcessorPipeline across multiple files

- Updated instances of DataProcessorPipeline to RobotProcessorPipeline in evaluate.py, record.py, replay.py, teleoperate.py, and other relevant files for consistency and clarity.
- Adjusted function signatures and variable types to reflect the new pipeline structure, enhancing maintainability and readability.
2025-09-03 19:01:28 +02:00
Steven Palma 6c7169c4af chore(processor): rename teleop_phone variable names (#1858) 2025-09-03 18:42:13 +02:00
Adil Zouitine f125d5e3bf refactor(processor): rename internal device variable for clarity (#1857)
- Changed the internal device variable from `_device` to `tensor_device` for improved readability and consistency.
- Updated references throughout the class to reflect the new variable name.
2025-09-03 18:39:06 +02:00
Steven Palma 75dcfd4886 chore(processor): rename merge_features -> combine_feature_dicts (#1856) 2025-09-03 18:20:35 +02:00
Adil Zouitine ff3cbaa872 refactor(processor): rename internal tokenizer variable for clarity (#1855)
- Changed the internal tokenizer variable name from `_tokenizer` to `input_tokenizer` for improved readability and consistency.
- Updated references throughout the class to reflect the new variable name.
2025-09-03 18:20:12 +02:00
Adil Zouitine ce793cde64 chore(processor): add Step suffix to all processors (#1854)
* refactor(processor): rename MapDeltaActionToRobotAction and MapTensorToDeltaActionDict for consistency

* refactor(processor): rename DeviceProcessor to DeviceProcessorStep for consistency across modules

* refactor(processor): rename Torch2NumpyActionProcessor to Torch2NumpyActionProcessorStep for consistency

* refactor(processor): rename Numpy2TorchActionProcessor to Numpy2TorchActionProcessorStep for consistency

* refactor(processor): rename AddTeleopActionAsComplimentaryData to AddTeleopActionAsComplimentaryDataStep for consistency

* refactor(processor): rename ImageCropResizeProcessor and AddTeleopEventsAsInfo for consistency

* refactor(processor): rename TimeLimitProcessor to TimeLimitProcessorStep for consistency

* refactor(processor): rename GripperPenaltyProcessor to GripperPenaltyProcessorStep for consistency

* refactor(processor): rename InterventionActionProcessor to InterventionActionProcessorStep for consistency

* refactor(processor): rename RewardClassifierProcessor to RewardClassifierProcessorStep for consistency

* refactor(processor): rename JointVelocityProcessor to JointVelocityProcessorStep for consistency

* refactor(processor): rename MotorCurrentProcessor to MotorCurrentProcessorStep for consistency

* refactor(processor): rename NormalizerProcessor and UnnormalizerProcessor to NormalizerProcessorStep and UnnormalizerProcessorStep for consistency

* refactor(processor): rename VanillaObservationProcessor to VanillaObservationProcessorStep for consistency

* refactor(processor): rename RenameProcessor to RenameProcessorStep for consistency

* refactor(processor): rename TokenizerProcessor to TokenizerProcessorStep for consistency

* refactor(processor): rename ToBatchProcessor to AddBatchDimensionProcessorStep for consistency

* refactor(processor): update config file name in test for RenameProcessorStep consistency
2025-09-03 18:12:11 +02:00
Steven Palma 029c4a9a76 chore(processor): rename converters function names (#1853)
* chore(processor): rename to_transition_teleop_action -> action_to_transition

* chore(processor): rename to_transition_robot_observation -> observation_to_transition

* chore(processor): rename to_output_robot_action -> transition_to_robot_action
2025-09-03 18:08:54 +02:00
Steven Palma d893bf1e30 chore(processor): rename specialized processor -> XYZProcessorStep (#1852) 2025-09-03 17:30:47 +02:00
Steven Palma 8c796b39f5 chore(processor): rename RobotProcessor -> DataProcessorPipeline (#1850) 2025-09-03 17:13:16 +02:00
Adil Zouitine 4ebe482a7e refactor(processors): enhance transform_features method across multiple processors (#1849)
* refactor(processors): enhance transform_features method across multiple processors

- Updated the transform_features method in various processors to utilize a copy of the features dictionary, ensuring immutability of the original features.
- Added handling for new feature keys and removed obsolete ones in the MapTensorToDeltaActionDict, JointVelocityProcessor, and others.
- Improved readability and maintainability by following consistent patterns in feature transformation.

* refactor(processors): standardize action and observation keys in delta_action_processor and joint_observations_processor

- Updated action and observation keys to use constants for improved readability and maintainability.
- Refactored the transform_features method in multiple processors to ensure consistent handling of feature keys.
- Enhanced error handling by raising exceptions for missing required components in action and observation processing.
- Removed obsolete code and improved overall structure for better clarity.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(processors): remove unused import in joint_observations_processor

* refactor(processors): simplify transform_features method in delta_action_processor

* refactor(processors): streamline transform_features method in ImageCropResizeProcessor

* refactor(processors): improve error handling and streamline transform_features method in phone_processor

- Raised a ValueError for missing position and rotation in action to enhance error handling.

* refactor(processors): enhance error handling in JointVelocityProcessor

- Added a ValueError raise for missing current joint positions in the observation method to improve error handling and ensure the integrity of the transform_features method.

* refactor(processors): simplify transform_features method in robot kinematic processors

* refactor(processors): standardize action keys in phone_processor

* fix(processor): RKP feature obs -> act

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-09-03 16:54:41 +02:00
Steven Palma 2fcc358e98 refactor(processors): add extended api for specialized pipelines (#1848) 2025-09-03 12:28:40 +02:00
Steven Palma b052843f08 refactor(processors): unify import statements by consolidating pipeline imports into the main processor module (#1845) 2025-09-02 18:26:59 +02:00
Steven Palma ebb464c255 refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844) 2025-09-02 17:57:49 +02:00
Steven Palma 2914ae2a96 refactor(processors): add transform_features method to various processors (#1843) 2025-09-02 17:15:01 +02:00
Adil Zouitine 645c87e3a9 refactor(converters): gather converters and refactor the logic (#1833)
* refactor(converters): move batch transition functions to converters module

- Moved `_default_batch_to_transition` and `_default_transition_to_batch` functions from `pipeline.py` to `converters.py` for better organization and separation of concerns.
- Updated references in `RobotProcessor` to use the new location of these functions.
- Added tests to ensure correct functionality of the transition functions, including handling of index and task_index fields.
- Removed redundant tests from `pipeline.py` to streamline the test suite.

* refactor(processor): reorganize EnvTransition and TransitionKey definitions

- Moved `EnvTransition` and `TransitionKey` classes from `pipeline.py` to a new `core.py` module for better structure and maintainability.
- Updated import statements across relevant modules to reflect the new location of these definitions, ensuring consistent access throughout the codebase.

* refactor(converters): rename and update dataset frame conversion functions

- Replaced `to_dataset_frame` with `transition_to_dataset_frame` for clarity and consistency in naming.
- Updated references in `record.py`, `pipeline.py`, and tests to use the new function name.
- Introduced `merge_transitions` to streamline the merging of transitions, enhancing readability and maintainability.
- Adjusted related tests to ensure correct functionality with the new naming conventions.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(processor): solve conflict artefacts

* refactor(converters): remove unused identity function and update type hints for merge_transitions

* refactor(processor): remove unused identity import and clean up gym_manipulator.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-09-02 15:33:38 +02:00
99 changed files with 8023 additions and 2725 deletions
+2
View File
@@ -35,6 +35,8 @@
title: Koch v1.1
- local: lekiwi
title: LeKiwi
- local: reachy2
title: Reachy 2
title: "Robots"
- sections:
- local: notebooks
+13 -13
View File
@@ -143,27 +143,27 @@ HIL-SERL uses a modular processor pipeline architecture that processes robot obs
The environment processor (`env_processor`) handles incoming observations and environment state:
1. **VanillaObservationProcessor**: Converts raw robot observations into standardized format
2. **JointVelocityProcessor** (optional): Adds joint velocity information to observations
3. **MotorCurrentProcessor** (optional): Adds motor current readings to observations
1. **VanillaObservationProcessorStep**: Converts raw robot observations into standardized format
2. **JointVelocityProcessorStep** (optional): Adds joint velocity information to observations
3. **MotorCurrentProcessorStep** (optional): Adds motor current readings to observations
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
5. **ImageCropResizeProcessor** (optional): Crops and resizes camera images
6. **TimeLimitProcessor** (optional): Enforces episode time limits
7. **GripperPenaltyProcessor** (optional): Applies penalties for inappropriate gripper usage
8. **RewardClassifierProcessor** (optional): Automated reward detection using vision models
9. **ToBatchProcessor**: Converts data to batch format for neural network processing
10. **DeviceProcessor**: Moves data to the specified compute device (CPU/GPU)
5. **ImageCropResizeProcessorStep** (optional): Crops and resizes camera images
6. **TimeLimitProcessorStep** (optional): Enforces episode time limits
7. **GripperPenaltyProcessorStep** (optional): Applies penalties for inappropriate gripper usage
8. **RewardClassifierProcessorStep** (optional): Automated reward detection using vision models
9. **AddBatchDimensionProcessorStep**: Converts data to batch format for neural network processing
10. **DeviceProcessorStep**: Moves data to the specified compute device (CPU/GPU)
#### Action Processor Pipeline
The action processor (`action_processor`) handles outgoing actions and human interventions:
1. **AddTeleopActionAsComplimentaryData**: Captures teleoperator actions for logging
2. **AddTeleopEventsAsInfo**: Records intervention events and episode control signals
1. **AddTeleopActionAsComplimentaryDataStep**: Captures teleoperator actions for logging
2. **AddTeleopEventsAsInfoStep**: Records intervention events and episode control signals
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
4. **InterventionActionProcessor**: Handles human interventions and episode termination
4. **InterventionActionProcessorStep**: Handles human interventions and episode termination
5. **Inverse Kinematics Pipeline** (when enabled):
- **MapDeltaActionToRobotAction**: Converts delta actions to robot action format
- **MapDeltaActionToRobotActionStep**: Converts delta actions to robot action format
- **EEReferenceAndDelta**: Computes end-effector reference and delta movements
- **EEBoundsAndSafety**: Enforces workspace safety bounds
- **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
+288
View File
@@ -0,0 +1,288 @@
# Reachy 2
Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!
## Teleoperate Reachy 2
Currently, there are two ways to teleoperate Reachy 2:
- Pollen Robotics VR teleoperation (not included in LeRobot).
- Robot-to-robot teleoperation (use one Reachy 2 to control another).
## Reachy 2 Simulation
**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).
1. Install [Docker Engine](https://docs.docker.com/engine/).
2. Run (for MuJoCo):
```
docker run --rm -it \
--name reachy \
--privileged \
--network host \
--ipc host \
--device-cgroup-rule='c 189:* rwm' \
--group-add audio \
-e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
-e DISPLAY="$DISPLAY" \
-e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
-e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
-v /dev:/dev \
-v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
-v "$HOME/.reachy.log":/home/reachy/.ros/log \
-v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
--entrypoint /package/launch.sh \
pollenrobotics/reachy2_core:1.7.5.9_deploy \
start_rviz:=true start_sdk_server:=true mujoco:=true
```
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
>
> ```
> docker run --rm -it \
> --name reachy \
> --privileged \
> --network host \
> --ipc host \
> --device-cgroup-rule='c 189:* rwm' \
> --group-add audio \
> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
> -e DISPLAY="$DISPLAY" \
> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
> -v /dev:/dev \
> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
> -v "$HOME/.reachy.log":/home/reachy/.ros/log \
> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
> --entrypoint /package/launch.sh \
> pollenrobotics/reachy2_core:1.7.5.9_deploy \
> start_rviz:=true start_sdk_server:=true mujoco:=true
> ```
## Setup
### Prerequisites
- On your robot, check the **service images** meet the minimum versions:
- **reachy2-core >= 1.7.5.2**
- **webrtc >= 2.0.1.1**
Then, if you want to use VR teleoperation:
- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
Use version **>=v1.2.0**
We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.
### Install LeRobot
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
Install LeRobot with Reachy 2 dependencies:
```bash
pip install -e ".[reachy2]"
```
### (Optional but recommended) Install pollen_data_acquisition_server
How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.
> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. Youre free to develop custom clients to manage sessions to your needs.
In your LeRobot environment, install the server from source:
```bash
git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
cd pollen_data_acquisition_server
pip install -e .
```
Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).
## Step 1: Recording
### Get Reachy 2 IP address
Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
We strongly recommend connecting all devices (PC and robot) via **Ethernet**.
### Launch recording
There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:
- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robots motions.
- **Using LeRobots record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, its only for motion control.
### Option 1: Using Pollen data acquisition server (recommended for VR teleop)
Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.
Launch the data acquisition server to be able to manage your session directly from the teleoperation application:
```bash
python -m pollen_data_acquisition_server.server
```
Then get into the teleoperation application and choose "Data acquisition session".
You can finally setup your session by following the screens displayed.
> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.
### Option 2: Using lerobot.record
Reachy 2 is fully supported by LeRobots recording features.
If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.
**Example: start a recording without the mobile base:**
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.id=r2-0000 \
--robot.use_external_commands=true \
--robot.with_mobile_base=false \
--teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \
--teleop.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--display_data=true
```
#### Specific Options
**Extended setup overview (all options included):**
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.use_external_commands=true \
--robot.with_mobile_base=true \
--robot.with_l_arm=true \
--robot.with_r_arm=true \
--robot.with_neck=true \
--robot.with_antennas=true \
--robot.with_left_teleop_camera=true \
--robot.with_right_teleop_camera=true \
--robot.with_torso_camera=false \
--robot.disable_torque_on_disconnect=false \
--robot.max_relative_target=5.0 \
--teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \
--teleop.use_present_position=false \
--teleop.with_mobile_base=false \
--teleop.with_l_arm=true \
--teleop.with_r_arm=true \
--teleop.with_neck=true \
--teleop.with_antennas=true \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--display_data=true
```
##### `--robot.use_external_commands`
Determine whether LeRobot robot.send_action() sends commands to the robot.
**Must** be set to false while using the VR teleoperation application, as the app already sends commands.
##### `--teleop.use_present_position`
Determine whether the teleoperator reads the goal or present position of the robot.
Must be set to true if a compliant Reachy 2 is used to control another one.
##### Use the relevant parts
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
To avoid this, you can exclude specific parts from recording and replay using:
````
--robot.with_<part>=false
```,
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
It determine whether the corresponding part is recorded in the observations. True if not set.
By default, **all parts are recorded**.
The same per-part mechanism is available in `reachy2_teleoperator` as well.
````
--teleop.with\_<part>
```
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
Determine whether the corresponding part is recorded in the actions. True if not set.
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
##### Use the relevant cameras
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
```
--robot.with_left_teleop_camera=<true|false>
--robot.with_right_teleop_camera=<true|false>
--robot.with_torso_camera=<true|false>
````
## Step 2: Replay
Make sure the robot is configured with the same parts as the dataset:
```bash
python -m lerobot.replay \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--robot.use_external_commands=false \
--robot.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \
--dataset.episode=0
--display_data=true
````
## Step 3: Train
```bash
python -m lerobot.scripts.train \
--dataset.repo_id=pollen_robotics/record_test \
--policy.type=act \
--output_dir=outputs/train/reachy2_test \
--job_name=reachy2 \
--policy.device=mps \
--wandb.enable=true \
--policy.repo_id=pollen_robotics/record_test_policy
```
## Step 4: Evaluate
```bash
python -m lerobot.record \
--robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \
--display_data=false \
--dataset.repo_id=pollen_robotics/eval_record_test \
--dataset.single_task="Evaluate reachy2 policy" \
--dataset.num_episodes=10 \
--policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
```
+19 -18
View File
@@ -16,16 +16,17 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import merge_features
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
to_output_robot_action,
to_transition_robot_observation,
identity_transition,
observation_to_transition,
transition_to_action,
)
from lerobot.processor.pipeline import RobotProcessor
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
@@ -65,7 +66,7 @@ kinematics_solver = RobotKinematics(
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
robot_ee_to_joints_processor = RobotProcessorPipeline(
steps=[
AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints(
@@ -74,36 +75,36 @@ robot_ee_to_joints = RobotProcessor(
initial_guess_current_joints=True,
),
],
to_transition=lambda tr: tr,
to_output=to_output_robot_action,
to_transition=identity_transition,
to_output=transition_to_action,
)
# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessor(
robot_joints_to_ee_pose_processor = RobotProcessorPipeline(
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=to_transition_robot_observation,
to_output=lambda tr: tr,
to_transition=observation_to_transition,
to_output=identity_transition,
)
# Build dataset action and gripper features
action_ee_and_gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints,
initial_features={},
pipeline=robot_ee_to_joints_processor,
initial_features=create_initial_features(),
use_videos=True,
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
) # Get all ee action features + gripper pos action features
# Build dataset observation features
obs_ee = aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose,
initial_features=robot.observation_features,
pipeline=robot_joints_to_ee_pose_processor,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=True,
patterns=["observation.state.ee"],
) # Get all ee observation features
dataset_features = merge_features(obs_ee, action_ee_and_gripper)
dataset_features = combine_feature_dicts(obs_ee, action_ee_and_gripper)
print("All dataset features: ", dataset_features)
@@ -147,8 +148,8 @@ for episode_idx in range(NUM_EPISODES):
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
robot_action_processor=robot_ee_to_joints,
robot_observation_processor=robot_joints_to_ee_pose,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
dataset.save_episode()
+26 -25
View File
@@ -17,15 +17,16 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import merge_features
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
to_output_robot_action,
to_transition_robot_observation,
to_transition_teleop_action,
action_to_transition,
identity_transition,
observation_to_transition,
transition_to_action,
)
from lerobot.processor.pipeline import RobotProcessor
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
@@ -73,7 +74,7 @@ kinematics_solver = RobotKinematics(
)
# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor(
phone_to_robot_ee_pose_processor = RobotProcessorPipeline(
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot),
@@ -88,12 +89,12 @@ phone_to_robot_ee_pose = RobotProcessor(
max_ee_twist_step_rad=0.50,
),
],
to_transition=to_transition_teleop_action,
to_output=lambda tr: tr,
to_transition=action_to_transition,
to_output=identity_transition,
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
robot_ee_to_joints_processor = RobotProcessorPipeline(
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
@@ -105,31 +106,31 @@ robot_ee_to_joints = RobotProcessor(
speed_factor=20.0,
),
],
to_transition=lambda tr: tr,
to_output=to_output_robot_action,
to_transition=identity_transition,
to_output=transition_to_action,
)
# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessor(
robot_joints_to_ee_pose = RobotProcessorPipeline(
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=to_transition_robot_observation,
to_output=lambda tr: tr,
to_transition=observation_to_transition,
to_output=identity_transition,
)
# Build dataset ee action features
action_ee = aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose,
initial_features=phone.action_features,
pipeline=phone_to_robot_ee_pose_processor,
initial_features=create_initial_features(action=phone.action_features),
use_videos=True,
patterns=["action.ee"],
)
# Get gripper pos action features
gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints,
initial_features={},
pipeline=robot_ee_to_joints_processor,
initial_features=create_initial_features(),
use_videos=True,
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
)
@@ -137,12 +138,12 @@ gripper = aggregate_pipeline_dataset_features(
# Build dataset ee observation features
observation_ee = aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose,
initial_features=robot.observation_features,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=True,
patterns=["observation.state.ee"],
)
dataset_features = merge_features(action_ee, gripper, observation_ee)
dataset_features = combine_feature_dicts(action_ee, gripper, observation_ee)
print("All dataset features: ", dataset_features)
@@ -177,8 +178,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose,
robot_action_processor=robot_ee_to_joints,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
@@ -193,8 +194,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose,
robot_action_processor=robot_ee_to_joints,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
+7 -7
View File
@@ -19,8 +19,8 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
from lerobot.processor.pipeline import RobotProcessor
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
@@ -50,7 +50,7 @@ kinematics_solver = RobotKinematics(
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
robot_ee_to_joints_processor = RobotProcessorPipeline(
steps=[
AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints(
@@ -59,11 +59,11 @@ robot_ee_to_joints = RobotProcessor(
initial_guess_current_joints=False, # Because replay is open loop
),
],
to_transition=to_transition_teleop_action,
to_output=to_output_robot_action,
to_transition=action_to_transition,
to_output=transition_to_action,
)
robot_ee_to_joints.reset()
robot_ee_to_joints_processor.reset()
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(dataset.num_frames):
@@ -73,7 +73,7 @@ for idx in range(dataset.num_frames):
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
}
joint_action = robot_ee_to_joints(ee_action)
joint_action = robot_ee_to_joints_processor(ee_action)
action_sent = robot.send_action(joint_action)
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
+6 -6
View File
@@ -16,8 +16,8 @@
import time
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
@@ -49,7 +49,7 @@ kinematics_solver = RobotKinematics(
)
# Build pipeline to convert phone action to ee pose action to joint action
phone_to_robot_joints = RobotProcessor(
phone_to_robot_joints_processor = RobotProcessorPipeline(
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot),
@@ -72,8 +72,8 @@ phone_to_robot_joints = RobotProcessor(
speed_factor=20.0,
),
],
to_transition=to_transition_teleop_action,
to_output=to_output_robot_action,
to_transition=action_to_transition,
to_output=transition_to_action,
)
robot.connect()
@@ -85,7 +85,7 @@ while True:
phone_obs = teleop_device.get_action()
# Phone -> EE pose -> Joints transition
joint_action = phone_to_robot_joints(phone_obs)
joint_action = phone_to_robot_joints_processor(phone_obs)
if joint_action:
robot.send_action(joint_action)
+2 -1
View File
@@ -73,7 +73,6 @@ dependencies = [
"pynput>=1.7.7",
"pyserial>=3.5",
"wandb>=0.20.0",
"scipy>=1.15.2",
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
@@ -107,6 +106,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
reachy2 = ["reachy2_sdk>=1.0.14"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
@@ -143,6 +143,7 @@ all = [
"lerobot[gamepad]",
"lerobot[hopejr]",
"lerobot[lekiwi]",
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[pi0]",
@@ -0,0 +1,16 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_reachy2_camera import Reachy2CameraConfig
from .reachy2_camera import Reachy2Camera
@@ -0,0 +1,78 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
@CameraConfig.register_subclass("reachy2_camera")
@dataclass
class Reachy2CameraConfig(CameraConfig):
"""Configuration class for Reachy 2 camera devices.
This class provides configuration options for Reachy 2 cameras,
supporting both the teleop and depth cameras. It includes settings
for resolution, frame rate, color mode, and the selection of the cameras.
Example configurations:
```python
# Basic configurations
Reachy2CameraConfig(
name="teleop",
image_type="left",
ip_address="192.168.0.200", # IP address of the robot
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
) # Left teleop camera, 640x480 @ 15FPS
```
Attributes:
name: Name of the camera device. Can be "teleop" or "depth".
image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
fps: Requested frames per second for the color stream.
width: Requested frame width in pixels for the color stream.
height: Requested frame height in pixels for the color stream.
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
ip_address: IP address of the robot. Defaults to "localhost".
port: Port number for the camera server. Defaults to 50065.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
"""
name: str
image_type: str
color_mode: ColorMode = ColorMode.RGB
ip_address: str | None = "localhost"
port: int = 50065
# use_depth: bool = False
def __post_init__(self):
if self.name not in ["teleop", "depth"]:
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
self.name == "depth" and self.image_type not in ["rgb", "depth"]
):
raise ValueError(
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
@@ -0,0 +1,288 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
"""
import logging
import os
import platform
import time
from threading import Event, Lock, Thread
from typing import Any
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
from lerobot.errors import DeviceNotConnectedError
from ..camera import Camera
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
logger = logging.getLogger(__name__)
class Reachy2Camera(Camera):
"""
Manages Reachy 2 camera using Reachy 2 CameraManager.
This class provides a high-level interface to connect to, configure, and read
frames from Reachy 2 cameras. It supports both synchronous and asynchronous
frame reading.
An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
type (e.g., "left") to be specified in the configuration.
The camera's default settings (FPS, resolution, color mode) are used unless
overridden in the configuration.
"""
def __init__(self, config: Reachy2CameraConfig):
"""
Initializes the Reachy2Camera instance.
Args:
config: The configuration settings for the camera.
"""
super().__init__(config)
self.config = config
self.fps = config.fps
self.color_mode = config.color_mode
self.cam_manager: CameraManager | None = None
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
@property
def is_connected(self) -> bool:
"""Checks if the camera is currently connected and opened."""
if self.config.name == "teleop":
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
elif self.config.name == "depth":
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
def connect(self, warmup: bool = True):
"""
Connects to the Reachy2 CameraManager as specified in the configuration.
"""
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
self.cam_manager.initialize_cameras()
logger.info(f"{self} connected.")
@staticmethod
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
"""
Detects available Reachy 2 cameras.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains 'name', 'stereo',
and the default profile properties (width, height, fps).
"""
initialized_cameras = []
camera_manager = CameraManager(host=ip_address, port=port)
for camera in [camera_manager.teleop, camera_manager.depth]:
if camera is None:
continue
height, width, _, _, _, _, _ = camera.get_parameters()
camera_info = {
"name": camera._cam_info.name,
"stereo": camera._cam_info.stereo,
"default_profile": {
"width": width,
"height": height,
"fps": 30,
},
}
initialized_cameras.append(camera_info)
camera_manager.disconnect()
return initialized_cameras
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Reads a single frame synchronously from the camera.
This is a blocking call.
Args:
color_mode (Optional[ColorMode]): If specified, overrides the default
color mode (`self.color_mode`) for this read operation (e.g.,
request RGB even if default is BGR).
Returns:
np.ndarray: The captured frame as a NumPy array in the format
(height, width, channels), using the specified or default
color mode and applying any configured rotation.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter()
frame = None
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
else:
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
if self.config.image_type == "left":
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
elif self.config.image_type == "right":
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.config.image_type == "depth":
frame = self.cam_manager.depth.get_depth_frame()[0]
elif self.config.image_type == "rgb":
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
if frame is None:
return np.empty((0, 0, 3), dtype=np.uint8)
if self.config.color_mode == "rgb":
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
read_duration_ms = (time.perf_counter() - start_time) * 1e3
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
return frame
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
while not self.stop_event.is_set():
try:
color_image = self.read()
with self.frame_lock:
self.latest_frame = color_image
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame asynchronously.
This method retrieves the most recent frame captured by the background
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame
to become available. Defaults to 200ms (0.2 seconds).
Returns:
np.ndarray: The latest captured frame as a NumPy array in the format
(height, width, channels), processed according to configuration.
Raises:
DeviceNotConnectedError: If the camera is not connected.
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None:
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
return frame
def disconnect(self):
"""
Stops the background read thread (if running).
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected and self.thread is None:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.thread is not None:
self._stop_read_thread()
if self.cam_manager is not None:
self.cam_manager.disconnect()
logger.info(f"{self} disconnected.")
+7 -1
View File
@@ -37,8 +37,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
from .realsense.camera_realsense import RealSenseCamera
cameras[key] = RealSenseCamera(cfg)
elif cfg.type == "reachy2_camera":
from .reachy2_camera.reachy2_camera import Reachy2Camera
cameras[key] = Reachy2Camera(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
return cameras
+5
View File
@@ -27,6 +27,11 @@ class FeatureType(str, Enum):
LANGUAGE = "LANGUAGE"
class PipelineFeatureType(str, Enum):
ACTION = "ACTION"
OBSERVATION = "OBSERVATION"
class NormalizationMode(str, Enum):
MIN_MAX = "MIN_MAX"
MEAN_STD = "MEAN_STD"
+2 -2
View File
@@ -45,8 +45,8 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor"
POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor"
if "LEROBOT_HOME" in os.environ:
raise ValueError(
+106 -60
View File
@@ -12,84 +12,130 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections.abc import Sequence
from typing import Any
from lerobot.configs.types import PipelineFeatureType
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor.pipeline import RobotProcessor
from lerobot.processor import DataProcessorPipeline
def create_initial_features(
action: dict[str, Any] | None, observation: dict[str, Any] | None
) -> dict[PipelineFeatureType, dict[str, Any]]:
"""
Creates the initial features dict for the dataset from action and observation specs.
Args:
action: A dictionary of action feature names to their types/shapes.
observation: A dictionary of observation feature names to their types/shapes.
Returns:
The initial features dictionary structured by PipelineFeatureType.
"""
features = {PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}}
if action:
features[PipelineFeatureType.ACTION] = action
if observation:
features[PipelineFeatureType.OBSERVATION] = observation
return features
# Helper to filter state/action keys based on regex patterns.
def should_keep(key: str, patterns: tuple[str]) -> bool:
if patterns is None:
return True
return any(re.search(pat, key) for pat in patterns)
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
for prefix in prefixes_to_strip:
if key.startswith(prefix):
return key[len(prefix) :]
return key
# Define prefixes to strip from feature keys for clean names.
# Handles both fully qualified (e.g., "action.state") and short (e.g., "state") forms.
PREFIXES_TO_STRIP = tuple(
f"{token}." for const in (ACTION, OBS_STATE, OBS_IMAGES) for token in (const, const.split(".")[-1])
)
def aggregate_pipeline_dataset_features(
pipeline: RobotProcessor,
initial_features: dict[str, Any],
pipeline: DataProcessorPipeline,
initial_features: dict[PipelineFeatureType, dict[str, Any]],
*,
use_videos: bool = True,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates the pipeline's features and returns a features dict ready for the dataset,
filtered to only those keys matching any of the given patterns (for action/state only).
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
- `use_videos`: whether to treat image features as video streams
- `patterns`: regexes to filter action & state features; images are included
whenever use_videos=True, regardless of patterns.
This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `use_videos` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset.
Args:
pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: If False, image features are excluded.
patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter.
Returns:
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
"""
import re
# Gather everything the pipeline features specifies, seeded with hardware cams:
all_features = pipeline.transform_features(initial_features)
# Helper to decide which action/state keys survive the `patterns` filter:
def keep(key: str) -> bool:
if patterns is None:
return True
return any(re.search(pat, key) for pat in patterns)
# Intermediate storage for categorized and filtered features.
processed_features: dict[str, dict[str, Any]] = {
"action": {},
"observation": {},
}
images_token = OBS_IMAGES.split(".")[-1]
# Start with hardware dict, injecting initial cameras if videos are ON:
hw: dict[str, dict[str, Any]] = {}
if use_videos:
cams = {
name: shape
for name, shape in initial_features.items()
if isinstance(shape, tuple) and len(shape) == 3
}
if cams:
hw["observation"] = dict(cams)
# Go over every feature from the pipeline and merge:
for full_key, ty in all_features.items():
if full_key.startswith(f"{ACTION}."):
# action.<feat>
if not keep(full_key):
continue
name = full_key[len(f"{ACTION}.") :]
hw.setdefault(ACTION, {})[name] = ty
elif full_key.startswith(f"{OBS_STATE}."):
# observation.state.<feat>
if not keep(full_key):
continue
name = full_key[len(f"{OBS_STATE}.") :]
hw.setdefault("observation", {})[name] = ty
elif full_key.startswith(f"{OBS_IMAGES}."):
# observation.images.<cam>
# images obey ONLY the use_videos flag, not patterns
if not use_videos:
continue
name = full_key[len(f"{OBS_IMAGES}.") :]
hw.setdefault("observation", {})[name] = ty
else:
# anything else (e.g. policy-only features) is ignored here
# Iterate through all features transformed by the pipeline.
for ptype, feats in all_features.items():
if ptype not in [PipelineFeatureType.ACTION, PipelineFeatureType.OBSERVATION]:
continue
out: dict[str, dict] = {}
if ACTION in hw:
out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
if "observation" in hw:
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
for key, value in feats.items():
# 1. Categorize the feature.
is_action = ptype == PipelineFeatureType.ACTION
# Observations are classified as images if their key matches image-related tokens or if the shape of the feature is 3.
# All other observations are treated as state.
is_image = not is_action and (
(isinstance(value, tuple) and len(value) == 3)
or (
key.startswith(f"{OBS_IMAGES}.")
or key.startswith(f"{images_token}.")
or f".{images_token}." in key
)
)
return out
# 2. Apply filtering rules.
if is_image and not use_videos:
continue
if not is_image and not should_keep(key, patterns):
continue
# 3. Add the feature to the appropriate group with a clean name.
name = strip_prefix(key, PREFIXES_TO_STRIP)
if is_action:
processed_features["action"][name] = value
else:
processed_features["observation"][name] = value
# Convert the processed features into the final dataset format.
dataset_features = {}
if processed_features["action"]:
dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
if processed_features["observation"]:
dataset_features.update(
hw_to_dataset_features(processed_features["observation"], "observation", use_videos)
)
return dataset_features
+541 -45
View File
@@ -75,13 +75,20 @@ DEFAULT_FEATURES = {
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
"""Flatten a nested dictionary by joining keys with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
>>> print(flatten_dict(dct))
{"a/b": 1, "a/c/d": 2, "e": 3}
Example:
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
>>> print(flatten_dict(dct))
{'a/b': 1, 'a/c/d': 2, 'e': 3}
Args:
d (dict): The dictionary to flatten.
parent_key (str): The base key to prepend to the keys in this level.
sep (str): The separator to use between keys.
Returns:
dict: A flattened dictionary.
"""
items = []
for k, v in d.items():
@@ -94,6 +101,20 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
def unflatten_dict(d: dict, sep: str = "/") -> dict:
"""Unflatten a dictionary with delimited keys into a nested dictionary.
Example:
>>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3}
>>> print(unflatten_dict(flat_dct))
{'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
Args:
d (dict): A dictionary with flattened keys.
sep (str): The separator used in the keys.
Returns:
dict: A nested dictionary.
"""
outdict = {}
for key, value in d.items():
parts = key.split(sep)
@@ -107,6 +128,16 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
"""Access an item in a nested dictionary using a flattened key.
Args:
obj (DictLike): The nested dictionary-like object.
flattened_key (str): A key with parts separated by `sep`.
sep (str): The separator used in the flattened key.
Returns:
Any: The value from the nested dictionary.
"""
split_keys = flattened_key.split(sep)
getter = obj[split_keys[0]]
if len(split_keys) == 1:
@@ -119,6 +150,19 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
"""Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible.
Converts torch.Tensor, np.ndarray, and np.generic types to lists or native Python types.
Args:
stats (dict): A dictionary that may contain non-serializable numeric types.
Returns:
dict: A dictionary with all values converted to JSON-serializable types.
Raises:
NotImplementedError: If a value has an unsupported type.
"""
serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
@@ -133,6 +177,17 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
"""Embed image bytes into the dataset table before saving to Parquet.
This function prepares a Hugging Face dataset for serialization by converting
image objects into an embedded format that can be stored in Arrow/Parquet.
Args:
dataset (datasets.Dataset): The input dataset, possibly containing image features.
Returns:
datasets.Dataset: The dataset with images embedded in the table storage.
"""
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
@@ -142,38 +197,94 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
def load_json(fpath: Path) -> Any:
"""Load data from a JSON file.
Args:
fpath (Path): Path to the JSON file.
Returns:
Any: The data loaded from the JSON file.
"""
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
"""Write data to a JSON file.
Creates parent directories if they don't exist.
Args:
data (dict): The dictionary to write.
fpath (Path): The path to the output JSON file.
"""
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def load_jsonlines(fpath: Path) -> list[Any]:
"""Load data from a JSON Lines file.
Args:
fpath (Path): Path to the JSON Lines file.
Returns:
list[Any]: A list of objects loaded from the file.
"""
with jsonlines.open(fpath, "r") as reader:
return list(reader)
def write_jsonlines(data: dict, fpath: Path) -> None:
"""Write a list of dictionaries to a JSON Lines file.
Creates parent directories if they don't exist.
Args:
data (dict): The list of dictionaries to write.
fpath (Path): The path to the output JSON Lines file.
"""
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(data)
def append_jsonlines(data: dict, fpath: Path) -> None:
"""Append a dictionary to a JSON Lines file.
Creates parent directories if they don't exist.
Args:
data (dict): The dictionary to append.
fpath (Path): The path to the JSON Lines file.
"""
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "a") as writer:
writer.write(data)
def write_info(info: dict, local_dir: Path):
"""Write dataset info metadata to its standard file path.
Args:
info (dict): The dataset information dictionary.
local_dir (Path): The root directory of the dataset.
"""
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
"""Load dataset info metadata from its standard file path.
Also converts shape lists to tuples for consistency.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
dict: The dataset information dictionary.
"""
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
@@ -181,16 +292,40 @@ def load_info(local_dir: Path) -> dict:
def write_stats(stats: dict, local_dir: Path):
"""Serialize and write dataset statistics to their standard file path.
Args:
stats (dict): The statistics dictionary (can contain tensors/numpy arrays).
local_dir (Path): The root directory of the dataset.
"""
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
"""Recursively cast numerical values in a stats dictionary to numpy arrays.
Args:
stats (dict): The statistics dictionary.
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
"""Load dataset statistics and cast numerical values to numpy arrays.
Returns None if the stats file doesn't exist.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
A dictionary of statistics or None if the file is not found.
"""
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
@@ -198,6 +333,13 @@ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
def write_task(task_index: int, task: dict, local_dir: Path):
"""Write a single task to the tasks metadata file.
Args:
task_index (int): The index of the task.
task (dict): The task description dictionary.
local_dir (Path): The root directory of the dataset.
"""
task_dict = {
"task_index": task_index,
"task": task,
@@ -206,6 +348,16 @@ def write_task(task_index: int, task: dict, local_dir: Path):
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
"""Load tasks from the tasks metadata file.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
A tuple containing:
- A dictionary mapping task index to task description.
- A dictionary mapping task description to task index.
"""
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
@@ -213,15 +365,36 @@ def load_tasks(local_dir: Path) -> tuple[dict, dict]:
def write_episode(episode: dict, local_dir: Path):
"""Write a single episode's metadata to the episodes metadata file.
Args:
episode (dict): The episode metadata dictionary.
local_dir (Path): The root directory of the dataset.
"""
append_jsonlines(episode, local_dir / EPISODES_PATH)
def load_episodes(local_dir: Path) -> dict:
"""Load episode metadata from the episodes metadata file.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
dict: A dictionary mapping episode index to episode metadata.
"""
episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
"""Write statistics for a single episode to the episode stats file.
Args:
episode_index (int): The index of the episode.
episode_stats (dict): The statistics for the episode.
local_dir (Path): The root directory of the dataset.
"""
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
@@ -229,6 +402,14 @@ def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path
def load_episodes_stats(local_dir: Path) -> dict:
"""Load per-episode statistics from the episode stats file.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
dict: A dictionary mapping episode index to its statistics dictionary.
"""
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
return {
item["episode_index"]: cast_stats_to_numpy(item["stats"])
@@ -239,12 +420,35 @@ def load_episodes_stats(local_dir: Path) -> dict:
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
"""Create a per-episode stats dictionary from a global stats dictionary.
This is used for backward compatibility with older datasets that only had global stats.
Args:
stats (dict): The global dataset statistics.
episodes (list[int]): A list of episode indices.
Returns:
dict: A dictionary mapping each episode index to the global stats.
"""
return dict.fromkeys(episodes, stats)
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
"""Load an image from a file into a numpy array.
Args:
fpath (str | Path): Path to the image file.
dtype (np.dtype): The desired data type of the output array. If floating,
pixels are scaled to [0, 1].
channel_first (bool): If True, converts the image to (C, H, W) format.
Otherwise, it remains in (H, W, C) format.
Returns:
np.ndarray: The image as a numpy array.
"""
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
@@ -255,10 +459,19 @@ def load_image_as_numpy(
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
with channel first (c h w) of float32 type in range [0,1].
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
types are converted to torch.tensor.
Args:
items_dict (dict): A dictionary representing a batch of data from a
Hugging Face dataset.
Returns:
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
first_item = items_dict[key][0]
@@ -273,6 +486,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
def is_valid_version(version: str) -> bool:
"""Check if a string is a valid PEP 440 version.
Args:
version (str): The version string to check.
Returns:
bool: True if the version string is valid, False otherwise.
"""
try:
packaging.version.parse(version)
return True
@@ -286,6 +507,18 @@ def check_version_compatibility(
current_version: str | packaging.version.Version,
enforce_breaking_major: bool = True,
) -> None:
"""Check for version compatibility between a dataset and the current codebase.
Args:
repo_id (str): The repository ID for logging purposes.
version_to_check (str | packaging.version.Version): The version of the dataset.
current_version (str | packaging.version.Version): The current version of the codebase.
enforce_breaking_major (bool): If True, raise an error on major version mismatch.
Raises:
BackwardCompatibilityError: If the dataset version is from a newer, incompatible
major version of the codebase.
"""
v_check = (
packaging.version.parse(version_to_check)
if not isinstance(version_to_check, packaging.version.Version)
@@ -303,7 +536,14 @@ def check_version_compatibility(
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
"""Returns available valid versions (branches and tags) on given repo."""
"""Return available valid versions (branches and tags) on a given Hub repo.
Args:
repo_id (str): The repository ID on the Hugging Face Hub.
Returns:
list[packaging.version.Version]: A list of valid versions found.
"""
api = HfApi()
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
@@ -316,9 +556,22 @@ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
"""
Returns the version if available on repo or the latest compatible one.
Otherwise, will throw a `CompatibilityError`.
"""Return the specified version if available on repo, or the latest compatible one.
If the exact version is not found, it looks for the latest version with the
same major version number that is less than or equal to the target minor version.
Args:
repo_id (str): The repository ID on the Hugging Face Hub.
version (str | packaging.version.Version): The target version.
Returns:
str: The safe version string (e.g., "v1.2.3") to use as a revision.
Raises:
RevisionNotFoundError: If the repo has no version tags.
BackwardCompatibilityError: If only older major versions are available.
ForwardCompatibilityError: If only newer major versions are available.
"""
target_version = (
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
@@ -360,6 +613,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
def get_hf_features_from_features(features: dict) -> datasets.Features:
"""Convert a LeRobot features dictionary to a `datasets.Features` object.
Args:
features (dict): A LeRobot-style feature dictionary.
Returns:
datasets.Features: The corresponding Hugging Face `datasets.Features` object.
Raises:
ValueError: If a feature has an unsupported shape.
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
@@ -387,6 +651,14 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
def _validate_feature_names(features: dict[str, dict]) -> None:
"""Validate that feature names do not contain invalid characters.
Args:
features (dict): The LeRobot features dictionary.
Raises:
ValueError: If any feature name contains '/'.
"""
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
if invalid_features:
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
@@ -395,6 +667,22 @@ def _validate_feature_names(features: dict[str, dict]) -> None:
def hw_to_dataset_features(
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
This function takes a dictionary describing hardware outputs (like joint states
or camera image shapes) and formats it into the standard LeRobot feature
specification.
Args:
hw_features (dict): Dictionary mapping feature names to their type (float for
joints) or shape (tuple for images).
prefix (str): The prefix to add to the feature keys (e.g., "observation"
or "action").
use_video (bool): If True, image features are marked as "video", otherwise "image".
Returns:
dict: A LeRobot features dictionary.
"""
features = {}
joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
@@ -427,6 +715,20 @@ def hw_to_dataset_features(
def build_dataset_frame(
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
"""Construct a single data frame from raw values based on dataset features.
A "frame" is a dictionary containing all the data for a single timestep,
formatted as numpy arrays according to the feature specification.
Args:
ds_features (dict): The LeRobot dataset features dictionary.
values (dict): A dictionary of raw values from the hardware/environment.
prefix (str): The prefix to filter features by (e.g., "observation"
or "action").
Returns:
dict: A dictionary representing a single frame of data.
"""
frame = {}
for key, ft in ds_features.items():
if key in DEFAULT_FEATURES or not key.startswith(prefix):
@@ -440,6 +742,21 @@ def build_dataset_frame(
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert dataset features to policy features.
This function transforms the dataset's feature specification into a format
that a policy can use, classifying features by type (e.g., visual, state,
action) and ensuring correct shapes (e.g., channel-first for images).
Args:
features (dict): The LeRobot dataset features dictionary.
Returns:
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
Raises:
ValueError: If an image feature does not have a 3D shape.
"""
# TODO(aliberts): Implement "type" in dataset features and simplify this
policy_features = {}
for key, ft in features.items():
@@ -470,12 +787,20 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
return policy_features
def merge_features(*dicts: dict) -> dict:
"""
Merge LeRobot grouped feature dicts.
def combine_feature_dicts(*dicts: dict) -> dict:
"""Merge LeRobot grouped feature dicts.
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
- For others (observation.images.*), last one wins (if they are identical).
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
Args:
*dicts: A variable number of LeRobot feature dictionaries to merge.
Returns:
dict: A single merged feature dictionary.
Raises:
ValueError: If there's a dtype mismatch for a feature being merged.
"""
out: dict = {}
for d in dicts:
@@ -521,6 +846,18 @@ def create_empty_dataset_info(
use_videos: bool,
robot_type: str | None = None,
) -> dict:
"""Create a template dictionary for a new dataset's `info.json`.
Args:
codebase_version (str): The version of the LeRobot codebase.
fps (int): The frames per second of the data.
features (dict): The LeRobot features dictionary for the dataset.
use_videos (bool): Whether the dataset will store videos.
robot_type (str | None): The type of robot used, if any.
Returns:
dict: A dictionary with the initial dataset metadata.
"""
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
@@ -541,6 +878,18 @@ def create_empty_dataset_info(
def get_episode_data_index(
episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
"""Calculate the start and end indices for each episode in a flattened dataset.
Args:
episode_dicts (dict): A dictionary mapping episode index to episode metadata,
which must contain a "length" key.
episodes (list[int] | None): An optional list of episode indices to consider.
If None, all episodes are used.
Returns:
dict: A dictionary with "from" and "to" keys, containing torch tensors
with the start and end indices for each episode.
"""
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@@ -560,16 +909,19 @@ def check_timestamps_sync(
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
to account for possible numerical error.
"""Check if timestamps are separated by (1/fps) +/- tolerance.
This check ensures that consecutive timestamps within an episode are spaced
correctly, accounting for possible numerical errors. It ignores the boundaries
between episodes.
Args:
timestamps (np.ndarray): Array of timestamps in seconds.
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
episode_data_index (dict): A dictionary that includes 'to',
which identifies indices for the end of each episode.
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
fps (int): Frames per second. Used to check the expected difference between
consecutive timestamps.
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
raise_value_error (bool): Whether to raise a ValueError if the check fails.
@@ -577,7 +929,8 @@ def check_timestamps_sync(
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
Raises:
ValueError: If the check fails and `raise_value_error` is True.
ValueError: If `timestamps` and `episode_indices` shapes do not match, or if
the check fails and `raise_value_error` is True.
"""
if timestamps.shape != episode_indices.shape:
raise ValueError(
@@ -628,9 +981,23 @@ def check_timestamps_sync(
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
actual timestamps from the dataset.
"""Check if delta timestamps are multiples of 1/fps +/- tolerance.
This ensures that adding these delta timestamps to any existing timestamp in
the dataset will result in a value that aligns with the dataset's frame rate.
Args:
delta_timestamps (dict): A dictionary where values are lists of time
deltas in seconds.
fps (int): The frames per second of the dataset.
tolerance_s (float): The allowed tolerance in seconds.
raise_value_error (bool): If True, raises an error on failure.
Returns:
bool: True if all deltas are valid, False otherwise.
Raises:
ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
@@ -656,6 +1023,15 @@ def check_delta_timestamps(
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
"""Convert delta timestamps in seconds to delta indices in frames.
Args:
delta_timestamps (dict): A dictionary of time deltas in seconds.
fps (int): The frames per second of the dataset.
Returns:
dict: A dictionary of frame delta indices.
"""
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = [round(d * fps) for d in delta_ts]
@@ -664,9 +1040,17 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
def cycle(iterable):
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
"""Create a dataloader-safe cyclical iterator.
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
This is an equivalent of `itertools.cycle` but is safe for use with
PyTorch DataLoaders with multiple workers.
See https://github.com/pytorch/pytorch/issues/23900 for details.
Args:
iterable: The iterable to cycle over.
Yields:
Items from the iterable, restarting from the beginning when exhausted.
"""
iterator = iter(iterable)
while True:
@@ -677,8 +1061,14 @@ def cycle(iterable):
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
exists before creating it.
"""Create a branch on an existing Hugging Face repo.
Deletes the branch if it already exists before creating it.
Args:
repo_id (str): The ID of the repository.
branch (str): The name of the branch to create.
repo_type (str | None): The type of the repository (e.g., "dataset").
"""
api = HfApi()
@@ -696,9 +1086,20 @@ def create_lerobot_dataset_card(
dataset_info: dict | None = None,
**kwargs,
) -> DatasetCard:
"""
Keyword arguments will be used to replace values in src/lerobot/datasets/card_template.md.
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
"""Create a `DatasetCard` for a LeRobot dataset.
Keyword arguments are used to replace values in the card template.
Note: If specified, `license` must be a valid license identifier from
https://huggingface.co/docs/hub/repositories-licenses.
Args:
tags (list | None): A list of tags to add to the dataset card.
dataset_info (dict | None): The dataset's info dictionary, which will
be displayed on the card.
**kwargs: Additional keyword arguments to populate the card template.
Returns:
DatasetCard: The generated dataset card object.
"""
card_tags = ["LeRobot"]
@@ -730,19 +1131,16 @@ def create_lerobot_dataset_card(
class IterableNamespace(SimpleNamespace):
"""
A namespace object that supports both dictionary-like iteration and dot notation access.
Automatically converts nested dictionaries into IterableNamespaces.
"""A namespace object that supports both dictionary-like iteration and dot notation.
This class extends SimpleNamespace to provide:
- Dictionary-style iteration over keys
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
- Dictionary-like methods: items(), keys(), values()
- Recursive conversion of nested dictionaries
This class extends `SimpleNamespace` to provide dictionary-style iteration,
access to items via brackets (`obj["key"]`), and dictionary-like methods
(`items()`, `keys()`, `values()`). Nested dictionaries are recursively
converted to `IterableNamespace` objects.
Args:
dictionary: Optional dictionary to initialize the namespace
**kwargs: Additional keyword arguments passed to SimpleNamespace
dictionary (dict, optional): A dictionary to initialize the namespace with.
**kwargs: Additional keyword arguments to initialize the namespace.
Examples:
>>> data = {"name": "Alice", "details": {"age": 25}}
@@ -756,10 +1154,16 @@ class IterableNamespace(SimpleNamespace):
>>> for key, value in ns.items():
... print(f"{key}: {value}")
name: Alice
details: IterableNamespace(age=25)
details: <__main__.IterableNamespace object at ...>
"""
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
"""Initialize the IterableNamespace.
Args:
dictionary (dict, optional): Dictionary to populate the namespace.
**kwargs: Keyword arguments to populate the namespace.
"""
super().__init__(**kwargs)
if dictionary is not None:
for key, value in dictionary.items():
@@ -769,22 +1173,46 @@ class IterableNamespace(SimpleNamespace):
setattr(self, key, value)
def __iter__(self) -> Iterator[str]:
"""Return an iterator over the keys of the namespace."""
return iter(vars(self))
def __getitem__(self, key: str) -> Any:
"""Allow bracket-style access to attributes.
Args:
key (str): The name of the attribute.
Returns:
Any: The value of the attribute.
"""
return vars(self)[key]
def items(self):
"""Return a view of the namespace's (key, value) pairs."""
return vars(self).items()
def values(self):
"""Return a view of the namespace's values."""
return vars(self).values()
def keys(self):
"""Return a view of the namespace's keys."""
return vars(self).keys()
def validate_frame(frame: dict, features: dict):
"""Validate a single data frame against the dataset's feature specification.
Checks for missing/extra features, and validates the dtype and shape of each
provided feature.
Args:
frame (dict): The data frame to validate.
features (dict): The LeRobot features dictionary for the dataset.
Raises:
ValueError: If the frame does not match the feature specification.
"""
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
@@ -799,6 +1227,15 @@ def validate_frame(frame: dict, features: dict):
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
"""Check for missing or extra features in a frame.
Args:
actual_features (set[str]): The set of feature names present in the frame.
expected_features (set[str]): The set of feature names expected in the frame.
Returns:
str: An error message string if there's a mismatch, otherwise an empty string.
"""
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - expected_features
@@ -814,6 +1251,19 @@ def validate_features_presence(actual_features: set[str], expected_features: set
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
"""Validate the dtype and shape of a single feature's value.
Args:
name (str): The name of the feature.
feature (dict): The feature specification from the LeRobot features dictionary.
value: The value of the feature to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
Raises:
NotImplementedError: If the feature dtype is not supported for validation.
"""
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
@@ -829,6 +1279,17 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
"""Validate a feature that is expected to be a numpy array.
Args:
name (str): The name of the feature.
expected_dtype (str): The expected numpy dtype as a string.
expected_shape (list[int]): The expected shape.
value (np.ndarray): The numpy array to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
@@ -846,6 +1307,18 @@ def validate_feature_numpy_array(
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
"""Validate a feature that is expected to be an image or video frame.
Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
Args:
name (str): The name of the feature.
expected_shape (list[str]): The expected shape (C, H, W).
value: The image data to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
@@ -862,12 +1335,35 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
def validate_feature_string(name: str, value: str):
"""Validate a feature that is expected to be a string.
Args:
name (str): The name of the feature.
value (str): The value to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
"""Validate the episode buffer before it's written to disk.
Ensures the buffer has the required keys, contains at least one frame, and
has features consistent with the dataset's specification.
Args:
episode_buffer (dict): The buffer containing data for a single episode.
total_episodes (int): The current total number of episodes in the dataset.
features (dict): The LeRobot features dictionary for the dataset.
Raises:
ValueError: If the buffer is invalid.
NotImplementedError: If the episode index is manually set and doesn't match.
"""
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")
+22 -2
View File
@@ -127,9 +127,29 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
"""Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description")
task_result = env.call("task_description")
if isinstance(task_result, tuple):
task_result = list(task_result)
if not isinstance(task_result, list):
raise TypeError(f"Expected task_description to return a list, got {type(task_result)}")
if not all(isinstance(item, str) for item in task_result):
raise TypeError("All items in task_description result must be strings")
observation["task"] = task_result
elif hasattr(env.envs[0], "task"):
observation["task"] = env.call("task")
task_result = env.call("task")
if isinstance(task_result, tuple):
task_result = list(task_result)
if not isinstance(task_result, list):
raise TypeError(f"Expected task to return a list, got {type(task_result)}")
if not all(isinstance(item, str) for item in task_result):
raise TypeError("All items in task result must be strings")
observation["task"] = task_result
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
num_envs = observation[list(observation.keys())[0]].shape[0]
observation["task"] = ["" for _ in range(num_envs)]
+37 -18
View File
@@ -15,16 +15,16 @@
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
@@ -33,38 +33,57 @@ def make_act_pre_post_processors(
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""Creates the pre- and post-processing pipelines for the ACT policy.
The pre-processing pipeline handles normalization, batching, and device placement for the model inputs.
The post-processing pipeline handles unnormalization and moves the model outputs back to the CPU.
Args:
config (ACTConfig): The ACT policy configuration object.
dataset_stats (dict[str, dict[str, torch.Tensor]] | None): A dictionary containing dataset
statistics (e.g., mean and std) used for normalization. Defaults to None.
preprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the
preprocessor pipeline's constructor. Defaults to None.
postprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the
postprocessor pipeline's constructor. Defaults to None.
Returns:
tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: A tuple containing the
pre-processor pipeline and the post-processor pipeline.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
device=config.device,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -16,16 +16,16 @@
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
@@ -34,37 +34,63 @@ def make_diffusion_pre_post_processors(
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for a diffusion policy.
The pre-processing pipeline prepares the input data for the model by:
1. Renaming features (if a `rename_map` is provided in `preprocessor_kwargs`).
2. Normalizing the input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving the data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving the data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the diffusion policy,
containing feature definitions, normalization mappings, and device information.
dataset_stats: A dictionary of statistics used for normalization.
Defaults to None.
preprocessor_kwargs: Additional keyword arguments
for the pre-processor pipeline. Defaults to an empty dictionary.
postprocessor_kwargs: Additional keyword arguments
for the post-processor pipeline. Defaults to an empty dictionary.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+92 -31
View File
@@ -24,6 +24,7 @@ from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.envs.configs import EnvConfig
@@ -38,11 +39,26 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor.pipeline import ProcessorKwargs, RobotProcessor
from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
"""
Retrieves a policy class by its registered name.
This function uses dynamic imports to avoid loading all policy classes into memory
at once, improving startup time and reducing dependencies.
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
Returns:
The policy class corresponding to the given name.
Raises:
NotImplementedError: If the policy name is not recognized.
"""
if name == "tdmpc":
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@@ -84,6 +100,24 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
"""
Instantiates a policy configuration object based on the policy type.
This factory function simplifies the creation of policy configuration objects by
mapping a string identifier to the corresponding config class.
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
"reward_classifier".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
An instance of a `PreTrainedConfig` subclass.
Raises:
ValueError: If the `policy_type` is not recognized.
"""
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
@@ -107,7 +141,21 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
class ProcessorConfigKwargs(TypedDict, total=False):
"""Keyword arguments for the processor config."""
"""
A TypedDict defining the keyword arguments for processor configuration.
This provides type hints for the optional arguments passed to `make_pre_post_processors`,
improving code clarity and enabling static analysis.
Attributes:
preprocessor_config_filename: The filename for the preprocessor configuration.
postprocessor_config_filename: The filename for the postprocessor configuration.
preprocessor_overrides: A dictionary of overrides for the preprocessor configuration.
postprocessor_overrides: A dictionary of overrides for the postprocessor configuration.
dataset_stats: Dataset statistics for normalization.
preprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`.
postprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`.
"""
preprocessor_config_filename: str | None
postprocessor_config_filename: str | None
@@ -122,23 +170,28 @@ def make_pre_post_processors(
policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[RobotProcessor, RobotProcessor]:
"""Make a processor instance for a given policy type.
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Create or load pre- and post-processor pipelines for a given policy.
This function creates the appropriate processor configuration based on the policy type.
Each policy type has its own processor with specific preprocessing steps.
This function acts as a factory. It can either load existing processor pipelines
from a pretrained path or create new ones from scratch based on the policy
configuration. Each policy type has a dedicated factory function for its
processors (e.g., `make_tdmpc_pre_post_processors`).
Args:
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
the processor from this path instead of creating a new one.
**kwargs: Additional keyword arguments passed to the processor creation.
policy_cfg: The configuration of the policy for which to create processors.
pretrained_path: An optional path to load pretrained processor pipelines from.
If provided, pipelines are loaded from this path.
**kwargs: Keyword arguments for processor configuration, as defined in
`ProcessorConfigKwargs`.
Returns:
Tuple of (input_processor, output_processor) for the policy.
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
Raises:
NotImplementedError: If the policy type doesn't have a processor implemented.
NotImplementedError: If a processor factory is not implemented for the given
policy configuration type.
"""
if pretrained_path:
# Extract preprocessor and postprocessor kwargs
@@ -146,16 +199,20 @@ def make_pre_post_processors(
postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})
return (
RobotProcessor.from_pretrained(
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=preprocessor_kwargs.get("to_transition"),
to_output=preprocessor_kwargs.get("to_output"),
),
RobotProcessor.from_pretrained(
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=postprocessor_kwargs.get("to_transition"),
to_output=postprocessor_kwargs.get("to_output"),
@@ -264,25 +321,29 @@ def make_policy(
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
"""Make an instance of a policy class.
"""
Instantiate a policy model.
This function exists because (for now) we need to parse features from either a dataset or an environment
in order to properly dimension and instantiate a policy for that dataset or environment.
This factory function handles the logic of creating a policy, which requires
determining the input and output feature shapes. These shapes can be derived
either from a `LeRobotDatasetMetadata` object or an `EnvConfig` object. The function
can either initialize a new policy from scratch or load a pretrained one.
Args:
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
provided if ds_meta is not. Defaults to None.
Raises:
ValueError: Either ds_meta or env and env_cfg must be provided.
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
cfg: The configuration for the policy to be created. If `cfg.pretrained_path` is
set, the policy will be loaded with weights from that path.
ds_meta: Dataset metadata used to infer feature shapes and types. Also provides
statistics for normalization layers.
env_cfg: Environment configuration used to infer feature shapes and types.
One of `ds_meta` or `env_cfg` must be provided.
Returns:
PreTrainedPolicy: _description_
An instantiated and device-placed policy model.
Raises:
ValueError: If both or neither of `ds_meta` and `env_cfg` are provided.
NotImplementedError: If attempting to use an unsupported policy-backend
combination (e.g., VQBeT with 'mps').
"""
if bool(ds_meta) == bool(env_cfg):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
+81 -30
View File
@@ -17,32 +17,45 @@
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import (
ComplementaryDataProcessor,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.rename_processor import RenameProcessor
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ComplementaryDataProcessor):
"""Add a new line to the end of the task if it doesn't have one.
This is required for the PaliGemma tokenizer.
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
"""
Ensures that the task description string ends with a newline character.
This processing step is required for compatibility with the PaliGemma tokenizer,
which expects a newline at the end of the text prompt. It handles both single
strings and lists of strings for the 'task' key in complementary data.
"""
def complementary_data(self, complementary_data):
"""
Adds a newline to the 'task' field if it doesn't already have one.
Args:
complementary_data: A dictionary that may contain a 'task' key with a
string or list of strings.
Returns:
A new dictionary with the modified 'task' field.
"""
if "task" not in complementary_data:
return complementary_data
@@ -64,13 +77,51 @@ class Pi0NewLineProcessor(ComplementaryDataProcessor):
return new_complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step does not alter the feature definitions.
Args:
features: The input feature dictionary.
Returns:
The unchanged feature dictionary.
"""
return features
def make_pi0_pre_post_processors(
config: PI0Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the PI0 policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
@@ -78,39 +129,39 @@ def make_pi0_pre_post_processors(
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
TokenizerProcessor(
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessor(device=config.device),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -16,55 +16,77 @@
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
def make_pi0fast_pre_post_processors(
config: PI0Config,
config: PI0FASTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the PI0Fast policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0Fast policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+40 -18
View File
@@ -17,16 +17,16 @@
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
@@ -35,37 +35,59 @@ def make_sac_pre_post_processors(
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the SAC policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the SAC policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -17,11 +17,11 @@ import torch
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.processor import (
DeviceProcessor,
IdentityProcessor,
NormalizerProcessor,
DeviceProcessorStep,
IdentityProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RobotProcessor,
)
@@ -30,30 +30,50 @@ def make_classifier_processor(
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the reward classifier.
The pre-processing pipeline prepares input data for the classifier by:
1. Normalizing both input and output features based on dataset statistics.
2. Moving the data to the specified device.
The post-processing pipeline handles the classifier's output by:
1. Moving the data to the CPU.
2. Applying an identity step, as no unnormalization is needed for the output logits.
Args:
config: The configuration object for the RewardClassifier.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
NormalizerProcessor(
NormalizerProcessorStep(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
NormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessor(device=config.device),
DeviceProcessorStep(device=config.device),
]
output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
output_steps = [DeviceProcessorStep(device="cpu"), IdentityProcessorStep()]
return (
RobotProcessor(
PolicyProcessorPipeline(
steps=input_steps,
name="classifier_preprocessor",
**preprocessor_kwargs,
),
RobotProcessor(
PolicyProcessorPipeline(
steps=output_steps,
name="classifier_postprocessor",
**postprocessor_kwargs,
@@ -16,21 +16,20 @@
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import (
ComplementaryDataProcessor,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
@@ -39,52 +38,82 @@ def make_smolvla_pre_post_processors(
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the SmolVLA policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Ensuring the language task description ends with a newline character.
5. Tokenizing the language task description.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output actions to their original scale.
Args:
config: The configuration object for the SmolVLA policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
SmolVLANewLineProcessor(),
TokenizerProcessor(
TokenizerProcessorStep(
tokenizer_name=config.vlm_model_name,
padding=config.pad_language_to,
padding_side="right",
max_length=config.tokenizer_max_length,
),
DeviceProcessor(device=config.device),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ComplementaryDataProcessor):
"""Add a new line to the end of the task if it doesn't have one."""
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
"""
A processor step that ensures the 'task' description ends with a newline character.
This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a
newline at the end of the prompt. It handles both single string tasks and lists
of string tasks.
"""
def complementary_data(self, complementary_data):
if "task" not in complementary_data:
@@ -107,3 +136,8 @@ class SmolVLANewLineProcessor(ComplementaryDataProcessor):
# If task is neither string nor list of strings, leave unchanged
return new_complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
+40 -18
View File
@@ -16,16 +16,16 @@
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
@@ -34,37 +34,59 @@ def make_tdmpc_pre_post_processors(
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the TDMPC policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the TDMPC policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+40 -18
View File
@@ -17,16 +17,16 @@
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
@@ -35,37 +35,59 @@ def make_vqbet_pre_post_processors(
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the VQ-BeT policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features, allowing customization to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the VQ-BeT policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}), # Let the possibility to the user to rename the keys
NormalizerProcessor(
RenameObservationsProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
RobotProcessor(
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+72 -56
View File
@@ -14,74 +14,90 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .batch_processor import ToBatchProcessor
from .delta_action_processor import MapDeltaActionToRobotAction, MapTensorToDeltaActionDict
from .device_processor import DeviceProcessor
from .gym_action_processor import Numpy2TorchActionProcessor, Torch2NumpyActionProcessor
from .hil_processor import (
AddTeleopActionAsComplimentaryData,
AddTeleopEventsAsInfo,
GripperPenaltyProcessor,
ImageCropResizeProcessor,
InterventionActionProcessor,
RewardClassifierProcessor,
TimeLimitProcessor,
from .batch_processor import AddBatchDimensionProcessorStep
from .converters import (
batch_to_transition,
create_transition,
merge_transitions,
transition_to_batch,
transition_to_dataset_frame,
)
from .joint_observations_processor import JointVelocityProcessor, MotorCurrentProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
from .observation_processor import VanillaObservationProcessor
from .core import EnvTransition, TransitionKey
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .device_processor import DeviceProcessorStep
from .gym_action_processor import Numpy2TorchActionProcessorStep, Torch2NumpyActionProcessorStep
from .hil_processor import (
AddTeleopActionAsComplimentaryDataStep,
AddTeleopEventsAsInfoStep,
GripperPenaltyProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
RewardClassifierProcessorStep,
TimeLimitProcessorStep,
)
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
from .observation_processor import VanillaObservationProcessorStep
from .pipeline import (
ActionProcessor,
DoneProcessor,
EnvTransition,
IdentityProcessor,
InfoProcessor,
ObservationProcessor,
ActionProcessorStep,
ComplementaryDataProcessorStep,
DataProcessorPipeline,
DoneProcessorStep,
IdentityProcessorStep,
InfoProcessorStep,
ObservationProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
ProcessorStep,
ProcessorStepRegistry,
RewardProcessor,
RobotProcessor,
TransitionKey,
TruncatedProcessor,
RewardProcessorStep,
RobotProcessorPipeline,
TruncatedProcessorStep,
)
from .rename_processor import RenameProcessor
from .tokenizer_processor import TokenizerProcessor
from .rename_processor import RenameObservationsProcessorStep
from .tokenizer_processor import TokenizerProcessorStep
__all__ = [
"ActionProcessor",
"AddTeleopActionAsComplimentaryData",
"AddTeleopEventsAsInfo",
"DeviceProcessor",
"DoneProcessor",
"MapDeltaActionToRobotAction",
"MapTensorToDeltaActionDict",
"ActionProcessorStep",
"AddTeleopActionAsComplimentaryDataStep",
"AddTeleopEventsAsInfoStep",
"ComplementaryDataProcessorStep",
"batch_to_transition",
"create_transition",
"DeviceProcessorStep",
"DoneProcessorStep",
"EnvTransition",
"GripperPenaltyProcessor",
"IdentityProcessor",
"ImageCropResizeProcessor",
"InfoProcessor",
"InterventionActionProcessor",
"JointVelocityProcessor",
"MapDeltaActionToRobotAction",
"MotorCurrentProcessor",
"NormalizerProcessor",
"UnnormalizerProcessor",
"GripperPenaltyProcessorStep",
"hotswap_stats",
"ObservationProcessor",
"IdentityProcessorStep",
"ImageCropResizeProcessorStep",
"InfoProcessorStep",
"InterventionActionProcessorStep",
"JointVelocityProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"merge_transitions",
"MotorCurrentProcessorStep",
"NormalizerProcessorStep",
"Numpy2TorchActionProcessorStep",
"ObservationProcessorStep",
"PolicyProcessorPipeline",
"ProcessorKwargs",
"ProcessorStep",
"ProcessorStepRegistry",
"RenameProcessor",
"RewardClassifierProcessor",
"RewardProcessor",
"RobotProcessor",
"ToBatchProcessor",
"TokenizerProcessor",
"TimeLimitProcessor",
"Numpy2TorchActionProcessor",
"Torch2NumpyActionProcessor",
"RenameObservationsProcessorStep",
"RewardClassifierProcessorStep",
"RewardProcessorStep",
"DataProcessorPipeline",
"TimeLimitProcessorStep",
"AddBatchDimensionProcessorStep",
"RobotProcessorPipeline",
"TokenizerProcessorStep",
"Torch2NumpyActionProcessorStep",
"transition_to_batch",
"transition_to_dataset_frame",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",
"TruncatedProcessorStep",
"UnnormalizerProcessorStep",
"VanillaObservationProcessorStep",
]
+159 -51
View File
@@ -1,3 +1,5 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,16 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script defines processor steps for adding a batch dimension to various components of an environment transition.
These steps are designed to process actions, observations, and complementary data, making them suitable for batch processing by adding a leading dimension. This is a common requirement before feeding data into a neural network model.
"""
from dataclasses import dataclass, field
from torch import Tensor
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import (
ActionProcessor,
ComplementaryDataProcessor,
EnvTransition,
ObservationProcessor,
from .core import EnvTransition
from .pipeline import (
ActionProcessorStep,
ComplementaryDataProcessorStep,
ObservationProcessorStep,
ProcessorStep,
ProcessorStepRegistry,
)
@@ -28,22 +39,66 @@ from lerobot.processor.pipeline import (
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_action")
class ToBatchProcessorAction(ActionProcessor):
"""Process action component in-place, adding batch dimension if needed."""
class AddBatchDimensionActionStep(ActionProcessorStep):
"""
Processor step to add a batch dimension to a 1D tensor action.
def action(self, action):
This is useful for creating a batch of size 1 from a single action sample.
"""
def action(self, action: Tensor) -> Tensor:
"""
Adds a batch dimension to the action if it's a 1D tensor.
Args:
action: The action tensor.
Returns:
The action tensor with an added batch dimension.
"""
if not isinstance(action, Tensor) or action.dim() != 1:
return action
return action.unsqueeze(0)
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Returns the input features unchanged.
Adding a batch dimension does not alter the feature definition.
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
class ToBatchProcessorObservation(ObservationProcessor):
"""Process observation component in-place, adding batch dimensions where needed."""
class AddBatchDimensionObservationStep(ObservationProcessorStep):
"""
Processor step to add a batch dimension to observations.
def observation(self, observation):
It handles different types of observations:
- State vectors (1D tensors).
- Single images (3D tensors).
- Dictionaries of multiple images (3D tensors).
"""
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
"""
Adds a batch dimension to tensor-based observations in the observation dictionary.
Args:
observation: The observation dictionary.
Returns:
The observation dictionary with batch dimensions added to tensors.
"""
# Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]:
if state_key in observation:
@@ -63,13 +118,44 @@ class ToBatchProcessorObservation(ObservationProcessor):
observation[key] = value.unsqueeze(0)
return observation
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Returns the input features unchanged.
Adding a batch dimension does not alter the feature definition.
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
"""Process complementary data in-place, handling task field batching."""
class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
"""
Processor step to add a batch dimension to complementary data fields.
def complementary_data(self, complementary_data):
Handles specific keys like 'task', 'index', and 'task_index' to make them batched.
- 'task' (str) is wrapped in a list.
- 'index' and 'task_index' (0D tensors) get a batch dimension.
"""
def complementary_data(self, complementary_data: dict) -> dict:
"""
Adds a batch dimension to specific fields in the complementary data dictionary.
Args:
complementary_data: The complementary data dictionary.
Returns:
The complementary data dictionary with batch dimensions added.
"""
# Process task field - wrap string in list to add batch dimension
if "task" in complementary_data:
task_value = complementary_data["task"]
@@ -89,54 +175,76 @@ class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
complementary_data["task_index"] = task_index_value.unsqueeze(0)
return complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Returns the input features unchanged.
Adding a batch dimension does not alter the feature definition.
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class ToBatchProcessor(ProcessorStep):
"""Processor that adds batch dimensions to observations and actions when needed.
class AddBatchDimensionProcessorStep(ProcessorStep):
"""
A composite processor step that adds a batch dimension to the entire environment transition.
This processor ensures that observations and actions have proper batch dimensions for model processing:
This step combines individual processors for actions, observations, and complementary data
to create a batched transition (batch size 1) from a single-instance transition.
- For state observations (observation.state, observation.environment_state):
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For image observations (observation.image, observation.images.*):
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
- For actions:
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For task field in complementary data:
Wraps string task in a list to add batch dimension
(task must be a string or list of strings)
This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to
batched model inputs.
The processor only modifies tensors that need batching and leaves already
batched tensors unchanged.
Example:
```python
# State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3)
# Action: (4,) -> (1, 4)
# Task: "pick_cube" -> ["pick_cube"]
# Already batched: (1, 7) -> (1, 7) [unchanged]
```
Attributes:
to_batch_action_processor: Processor for the action component.
to_batch_observation_processor: Processor for the observation component.
to_batch_complementary_data_processor: Processor for the complementary data component.
"""
to_batch_action_processor: ToBatchProcessorAction = field(default_factory=ToBatchProcessorAction)
to_batch_observation_processor: ToBatchProcessorObservation = field(
default_factory=ToBatchProcessorObservation
to_batch_action_processor: AddBatchDimensionActionStep = field(
default_factory=AddBatchDimensionActionStep
)
to_batch_complementary_data_processor: ToBatchProcessorComplementaryData = field(
default_factory=ToBatchProcessorComplementaryData
to_batch_observation_processor: AddBatchDimensionObservationStep = field(
default_factory=AddBatchDimensionObservationStep
)
to_batch_complementary_data_processor: AddBatchDimensionComplementaryDataStep = field(
default_factory=AddBatchDimensionComplementaryDataStep
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Applies the batching process to all relevant parts of an environment transition.
Args:
transition: The environment transition to process.
Returns:
The environment transition with a batch dimension added.
"""
transition = self.to_batch_action_processor(transition)
transition = self.to_batch_observation_processor(transition)
transition = self.to_batch_complementary_data_processor(transition)
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Returns the input features unchanged.
Adding a batch dimension does not alter the feature definition.
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
# NOTE: We ignore the batch dimension when transforming features
return features
+331 -134
View File
@@ -16,18 +16,17 @@
from __future__ import annotations
from collections.abc import Iterable, Sequence
from collections.abc import Sequence
from copy import deepcopy
from functools import singledispatch
from typing import Any
import numpy as np
import torch
from scipy.spatial.transform import Rotation
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
from .pipeline import EnvTransition, TransitionKey
from .core import EnvTransition, TransitionKey
@singledispatch
@@ -44,12 +43,12 @@ def to_tensor(
different input types appropriately.
Args:
value: Input value to convert (tensor, array, scalar, sequence, etc.)
value: Input value to convert (tensor, array, scalar, sequence, etc.).
dtype: Target tensor dtype. If None, preserves original dtype.
device: Target device for the tensor.
Returns:
PyTorch tensor.
A PyTorch tensor.
Raises:
TypeError: If the input type is not supported.
@@ -59,7 +58,7 @@ def to_tensor(
@to_tensor.register(torch.Tensor)
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle existing PyTorch tensors."""
"""Handle conversion for existing PyTorch tensors."""
if dtype is not None:
value = value.to(dtype=dtype)
if device is not None:
@@ -75,17 +74,17 @@ def _(
device=None,
**kwargs,
) -> torch.Tensor:
"""Handle numpy arrays."""
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
"""Handle conversion for numpy arrays."""
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars.
if value.ndim == 0:
# Numpy scalars should be converted to 0-dimensional tensors
# Numpy scalars should be converted to 0-dimensional tensors.
scalar_value = value.item()
return torch.tensor(scalar_value, dtype=dtype, device=device)
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
# Create tensor from numpy array.
tensor = torch.from_numpy(value)
# Apply dtype conversion if specified
# Apply dtype and device conversion if specified.
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if device is not None:
@@ -99,20 +98,20 @@ def _(
@to_tensor.register(np.integer)
@to_tensor.register(np.floating)
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle scalar values including numpy scalars."""
"""Handle conversion for scalar values including numpy scalars."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(list)
@to_tensor.register(tuple)
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle sequences (lists, tuples)."""
"""Handle conversion for sequences (lists, tuples)."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(dict)
def _(value: dict, *, device=None, **kwargs) -> dict:
"""Handle dictionaries by recursively converting values to tensors."""
"""Handle conversion for dictionaries by recursively converting their values to tensors."""
if not value:
return {}
@@ -122,7 +121,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
continue
if isinstance(sub_value, dict):
# Recursively process nested dictionaries
# Recursively process nested dictionaries.
result[key] = to_tensor(
sub_value,
device=device,
@@ -130,7 +129,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
)
continue
# Convert individual values to tensors
# Convert individual values to tensors.
result[key] = to_tensor(
sub_value,
device=device,
@@ -139,17 +138,46 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
return result
def _from_tensor(x: Any):
def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
"""
Convert a PyTorch tensor to a numpy array or scalar if applicable.
If the input is not a tensor, it is returned unchanged.
Args:
x: The input, which can be a tensor or any other type.
Returns:
A numpy array, a scalar, or the original input.
"""
if isinstance(x, torch.Tensor):
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
return x
def _is_image(arr: Any) -> bool:
"""
Check if a given array is likely an image (uint8, 3D).
Args:
arr: The array to check.
Returns:
True if the array matches the image criteria, False otherwise.
"""
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
"""
Separate an observation dictionary into state and image components.
Args:
obs: The observation dictionary.
Returns:
A tuple containing two dictionaries: one for state and one for images.
"""
state, images = {}, {}
for k, v in obs.items():
if "image" in k.lower() or _is_image(v):
@@ -159,168 +187,337 @@ def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any],
return state, images
def make_obs_act_transition(
*, obs: dict[str, Any] | None = None, act: dict[str, Any] | None = None
) -> EnvTransition:
return {
TransitionKey.OBSERVATION: {} if obs is None else obs,
TransitionKey.ACTION: {} if act is None else act,
TransitionKey.INFO: {},
TransitionKey.COMPLEMENTARY_DATA: {},
TransitionKey.REWARD: None,
TransitionKey.DONE: None,
TransitionKey.TRUNCATED: None,
}
# Private Helper Functions (Common Logic)
def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
Extract complementary data from a batch dictionary.
This includes padding flags, task description, and indices.
Args:
batch: The batch dictionary.
Returns:
A dictionary with the extracted complementary data.
"""
act_dict: dict[str, Any] = {}
for k, v in action.items():
# Check if the value is a type that should not be converted to a tensor.
if isinstance(v, (Rotation, dict)):
act_dict[f"{ACTION}.{k}"] = v
continue
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
arr = np.array(v) if np.isscalar(v) else v
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
return make_obs_act_transition(act=act_dict)
return {**pad_keys, **task_key, **index_key, **task_index_key}
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransition:
def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition:
"""
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
Merge two transitions, with the second one taking precedence in case of conflicts.
Args:
base: The base transition.
other: The transition to merge, which will overwrite base values.
Returns:
The merged transition dictionary.
"""
state, images = _split_obs_to_state_and_images(observation)
out = deepcopy(base)
obs_dict: dict[str, Any] = {}
for k, v in state.items():
arr = np.array(v) if np.isscalar(v) else v
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
for cam, img in images.items():
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
return make_obs_act_transition(obs=obs_dict)
def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
"""
Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions.
"""
out: dict[str, Any] = {}
action_dict = transition.get(TransitionKey.ACTION) or {}
if action_dict is None:
return out
for k, v in action_dict.items():
if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
out[out_key] = float(v)
for key in (
TransitionKey.OBSERVATION,
TransitionKey.ACTION,
TransitionKey.INFO,
TransitionKey.COMPLEMENTARY_DATA,
):
if other.get(key):
out.setdefault(key, {}).update(deepcopy(other[key]))
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
if k in other:
out[k] = other[k]
return out
def to_dataset_frame(
transitions_or_transition: EnvTransition | Iterable[EnvTransition], features: dict[str, dict]
) -> dict[str, any]:
# Core Conversion Functions
def create_transition(
observation: dict[str, Any] | None = None,
action: dict[str, Any] | None = None,
reward: float = 0.0,
done: bool = False,
truncated: bool = False,
info: dict[str, Any] | None = None,
complementary_data: dict[str, Any] | None = None,
) -> EnvTransition:
"""
Converts a single EnvTransition or an iterable of them into a flat,
dataset-friendly dictionary for training or evaluation, according to
the provided `features` spec.
Create an `EnvTransition` dictionary with sensible defaults.
Args:
transitions_or_transition: Either a single EnvTransition dict
or an iterable of them (which will be merged).
features (dict[str, dict]):
A feature specification dictionary:
- 'action': dict with 'names': list of action feature names
- 'observation.state': dict with 'names': list of state feature names
- keys starting with 'observation.images.' are passed through
observation: Observation dictionary.
action: Action dictionary.
reward: Scalar reward value.
done: Episode termination flag.
truncated: Episode truncation flag.
info: Additional info dictionary.
complementary_data: Complementary data dictionary.
Returns:
batch (dict[str, any]): Flat dictionary containing:
- numpy arrays for "observation.state" and "action"
- any image tensors defined in features
- next.{reward,done,truncated}
- info dict
- *_is_pad flags and task from complementary_data
A complete `EnvTransition` dictionary.
"""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info if info is not None else {},
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
}
def action_to_transition(action: dict[str, Any]) -> EnvTransition:
"""
Convert a raw action dictionary into a standardized `EnvTransition`.
The keys in the action dictionary are prefixed with "action." and stored under
the `ACTION` key in the transition. Values are converted to tensors, except for
special types like `Rotation`.
Args:
action: The raw action dictionary from a teleoperation device or controller.
Returns:
An `EnvTransition` containing the formatted action.
"""
return create_transition(observation={}, action=action)
def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
"""
Convert a raw robot observation dictionary into a standardized `EnvTransition`.
The observation is split into state and image components. State keys are prefixed
with "observation.state." and image keys with "observation.images.". The result is
stored under the `OBSERVATION` key in the transition.
Args:
observation: The raw observation dictionary from the environment.
Returns:
An `EnvTransition` containing the formatted observation.
"""
state, images = _split_obs_to_state_and_images(observation)
image_observations = {f"{OBS_IMAGES}.{cam}": img for cam, img in images.items()}
return create_transition(observation={**state, **image_observations}, action={})
def transition_to_action(transition: EnvTransition) -> dict[str, Any]:
"""
Extract a raw action dictionary for a robot from an `EnvTransition`.
This function searches for keys in the format "action.*.pos" or "action.*.vel"
and converts them into a flat dictionary suitable for sending to a robot controller.
Args:
transition: The `EnvTransition` containing the action.
Returns:
A dictionary representing the raw robot action.
"""
return transition.get(TransitionKey.ACTION)
def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition:
"""
Merge a sequence of transitions into a single one.
If a single transition is provided, it is returned as is. For a sequence,
transitions are merged sequentially, with later transitions in the sequence
overwriting earlier ones.
Args:
transitions: A single transition or a sequence of them.
Returns:
A single merged `EnvTransition`.
Raises:
ValueError: If an empty sequence of transitions is provided.
"""
if not isinstance(transitions, Sequence): # Single transition
return transitions
items = list(transitions)
if not items:
raise ValueError("merge_transitions() requires a non-empty sequence of transitions")
result = items[0]
for t in items[1:]:
result = _merge_transitions(result, t)
return result
def transition_to_dataset_frame(
transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict]
) -> dict[str, Any]:
"""
Convert one or more transitions into a flat dictionary suitable for a dataset frame.
This function processes `EnvTransition` objects according to a feature
specification, producing a format ready for training or evaluation.
Args:
transitions_or_transition: A single `EnvTransition` or a sequence to be merged.
features: A feature specification dictionary.
Returns:
A flat dictionary representing a single frame of data for a dataset.
"""
action_names = features.get(ACTION, {}).get("names", [])
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
out = deepcopy(base)
for key in (
TransitionKey.OBSERVATION,
TransitionKey.ACTION,
TransitionKey.INFO,
TransitionKey.COMPLEMENTARY_DATA,
):
if other.get(key):
out.setdefault(key, {}).update(deepcopy(other[key]))
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
if k in other:
out[k] = other[k]
return out
def _ensure_transition(obj) -> EnvTransition:
# single transition
if isinstance(obj, dict) and any(isinstance(k, TransitionKey) for k in obj):
return obj
# iterable of transitions
if isinstance(obj, Iterable):
items = list(obj)
if not items:
return {}
acc = items[0]
for t in items[1:]:
acc = _merge(acc, t)
return acc
raise TypeError("Expected EnvTransition or iterable of them")
tr = _ensure_transition(transitions_or_transition)
tr = merge_transitions(transitions_or_transition)
obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
act = tr.get(TransitionKey.ACTION, {}) or {}
batch: dict[str, any] = {}
batch: dict[str, Any] = {}
# Images passthrough
# Passthrough for images.
for k in image_keys:
if k in obs:
batch[k] = obs[k]
# Observation.state vector
# Create observation.state vector.
if obs_state_names:
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
vals = [from_tensor_to_numpy(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
# Action vector
# Create action vector.
if action_names:
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
vals = [from_tensor_to_numpy(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
batch[ACTION] = np.asarray(vals, dtype=np.float32)
# Add transition metadata.
if tr.get(TransitionKey.REWARD) is not None:
batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD])
if tr.get(TransitionKey.DONE) is not None:
batch[DONE] = _from_tensor(tr[TransitionKey.DONE])
if tr.get(TransitionKey.TRUNCATED) is not None:
batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED])
reward_val = from_tensor_to_numpy(tr[TransitionKey.REWARD])
# Check if features expect array format, otherwise keep as scalar.
if REWARD in features and features[REWARD].get("shape") == (1,):
batch[REWARD] = np.array([reward_val], dtype=np.float32)
else:
batch[REWARD] = reward_val
# Complementary data flags and task
if tr.get(TransitionKey.DONE) is not None:
done_val = from_tensor_to_numpy(tr[TransitionKey.DONE])
if DONE in features and features[DONE].get("shape") == (1,):
batch[DONE] = np.array([done_val], dtype=bool)
else:
batch[DONE] = done_val
if tr.get(TransitionKey.TRUNCATED) is not None:
truncated_val = from_tensor_to_numpy(tr[TransitionKey.TRUNCATED])
if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
else:
batch[TRUNCATED] = truncated_val
# Add complementary data flags and task.
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if comp:
# pad flags
# Padding flags.
for k, v in comp.items():
if k.endswith("_is_pad"):
batch[k] = v
# task label
# Task label.
if comp.get("task") is not None:
batch["task"] = comp["task"]
return batch
def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
"""
Convert a batch dictionary from a dataset/dataloader into an `EnvTransition`.
This function maps recognized keys from a batch to the `EnvTransition` structure,
filling in missing keys with sensible defaults.
Args:
batch: A batch dictionary.
Returns:
An `EnvTransition` dictionary.
Raises:
ValueError: If the input is not a dictionary.
"""
# Validate input type.
if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
# Extract observation and complementary data keys.
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
complementary_data = _extract_complementary_data(batch)
return create_transition(
observation=observation_keys if observation_keys else None,
action=batch.get("action"),
reward=batch.get("next.reward", 0.0),
done=batch.get("next.done", False),
truncated=batch.get("next.truncated", False),
info=batch.get("info", {}),
complementary_data=complementary_data if complementary_data else None,
)
def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"""
Convert an `EnvTransition` back to the canonical batch format used in LeRobot.
This is the inverse of `batch_to_transition`.
Args:
transition: The `EnvTransition` to convert.
Returns:
A batch dictionary with canonical LeRobot field names.
"""
batch = {
"action": transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}),
}
# Add complementary data.
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if comp_data:
batch.update(comp_data)
# Flatten observation dictionary.
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
batch.update(observation)
return batch
def identity_transition(tr: EnvTransition) -> EnvTransition:
"""
An identity function for transitions, returning the input unchanged.
Useful as a default or placeholder in processing pipelines.
Args:
tr: An `EnvTransition`.
Returns:
The same `EnvTransition`.
"""
return tr
+49
View File
@@ -0,0 +1,49 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from enum import Enum
from typing import Any, TypedDict
import torch
class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components."""
# TODO(Steven): Use consts
OBSERVATION = "observation"
ACTION = "action"
REWARD = "reward"
DONE = "done"
TRUNCATED = "truncated"
INFO = "info"
COMPLEMENTARY_DATA = "complementary_data"
EnvTransition = TypedDict(
"EnvTransition",
{
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
TransitionKey.ACTION.value: Any | torch.Tensor | None,
TransitionKey.REWARD.value: float | torch.Tensor | None,
TransitionKey.DONE.value: bool | torch.Tensor | None,
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
TransitionKey.INFO.value: dict[str, Any] | None,
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
},
)
+71 -59
View File
@@ -1,4 +1,4 @@
# !/usr/bin/env python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
@@ -18,60 +18,70 @@ from dataclasses import dataclass
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
@dataclass
class MapTensorToDeltaActionDict(ActionProcessor):
class MapTensorToDeltaActionDictStep(ActionProcessorStep):
"""
Map a tensor to a delta action dictionary.
Maps a flat action tensor from a policy to a structured delta action dictionary.
This step is typically used after a policy outputs a continuous action vector.
It decomposes the vector into named components for delta movements of the
end-effector (x, y, z) and optionally the gripper.
Attributes:
use_gripper: If True, assumes the 4th element of the tensor is the
gripper action.
"""
use_gripper: bool = True
def action(self, action: Tensor) -> dict:
if isinstance(action, dict):
return action
if action.dim() > 1:
action = action.squeeze(0)
# TODO (maractingi): add rotation
delta_action = {
"action.delta_x": action[0],
"action.delta_y": action[1],
"action.delta_z": action[2],
"delta_x": action[0],
"delta_y": action[1],
"delta_z": action[2],
}
if action.shape[0] > 3:
delta_action["action.gripper"] = action[3]
if self.use_gripper:
delta_action["gripper"] = action[3]
return delta_action
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION]["delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
if self.use_gripper:
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
return features
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
@dataclass
class MapDeltaActionToRobotAction(ActionProcessor):
class MapDeltaActionToRobotActionStep(ActionProcessorStep):
"""
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
for use with inverse kinematics processors.
Maps delta actions from teleoperators to robot target actions for inverse kinematics.
Expected input ACTION keys:
{
"action.delta_x": float,
"action.delta_y": float,
"action.delta_z": float,
"action.gripper": float (optional),
}
This step converts a dictionary of delta movements (e.g., from a gamepad)
into a target action format that includes an "enabled" flag and target
end-effector positions. It also handles scaling and noise filtering.
Output ACTION keys:
{
"action.enabled": bool,
"action.target_x": float,
"action.target_y": float,
"action.target_z": float,
"action.target_wx": float,
"action.target_wy": float,
"action.target_wz": float,
"action.gripper": float,
}
Attributes:
position_scale: A factor to scale the delta position inputs.
rotation_scale: A factor to scale the delta rotation inputs (currently unused).
noise_threshold: The magnitude below which delta inputs are considered noise
and do not trigger an "enabled" state.
"""
# Scale factors for delta movements
@@ -82,10 +92,10 @@ class MapDeltaActionToRobotAction(ActionProcessor):
def action(self, action: dict) -> dict:
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
delta_x = action.pop("action.delta_x", 0.0)
delta_y = action.pop("action.delta_y", 0.0)
delta_z = action.pop("action.delta_z", 0.0)
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
delta_x = action.pop("delta_x", 0.0)
delta_y = action.pop("delta_y", 0.0)
delta_z = action.pop("delta_z", 0.0)
gripper = action.pop("gripper", 1.0) # Default to "stay" (1.0)
# Determine if the teleoperator is actively providing input
# Consider enabled if any significant movement delta is detected
@@ -105,31 +115,33 @@ class MapDeltaActionToRobotAction(ActionProcessor):
# Update action with robot target format
action = {
"action.enabled": enabled,
"action.target_x": scaled_delta_x,
"action.target_y": scaled_delta_y,
"action.target_z": scaled_delta_z,
"action.target_wx": target_wx,
"action.target_wy": target_wy,
"action.target_wz": target_wz,
"action.gripper": float(gripper),
"enabled": enabled,
"target_x": scaled_delta_x,
"target_y": scaled_delta_y,
"target_z": scaled_delta_z,
"target_wx": target_wx,
"target_wy": target_wy,
"target_wz": target_wz,
"gripper": float(gripper),
}
return action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Transform features to match output format."""
# Update features to reflect the new action format
features.update(
{
"action.enabled": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_x": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_y": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_z": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_wx": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_wy": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.target_wz": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
"action.gripper": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
}
)
features[PipelineFeatureType.ACTION].pop("delta_x", None)
features[PipelineFeatureType.ACTION].pop("delta_y", None)
features[PipelineFeatureType.ACTION].pop("delta_z", None)
features[PipelineFeatureType.ACTION].pop("gripper", None)
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
+81 -23
View File
@@ -13,24 +13,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script defines a processor step for moving environment transition data to a specific torch device and casting
its floating-point precision.
"""
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.utils import get_safe_torch_device
from .core import EnvTransition, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register("device_processor")
@dataclass
class DeviceProcessor(ProcessorStep):
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
class DeviceProcessorStep(ProcessorStep):
"""
Processor step to move all tensors within an `EnvTransition` to a specified device and optionally cast their
floating-point data type.
This processor ensures that all tensors in the transition are moved to the
specified device (CPU or GPU) before they are returned. It can also convert
floating-point tensors to a specified dtype while preserving non-float types
(int, long, bool, etc.).
This is crucial for preparing data for model training or inference on hardware like GPUs.
Attributes:
device: The target device for tensors (e.g., "cpu", "cuda", "cuda:0").
float_dtype: The target floating-point dtype as a string (e.g., "float32", "float16", "bfloat16").
If None, the dtype is not changed.
"""
device: str = "cpu"
@@ -47,8 +60,15 @@ class DeviceProcessor(ProcessorStep):
}
def __post_init__(self):
self._device: torch.device = get_safe_torch_device(self.device)
self.device = self._device.type # cuda might have changed to cuda:1
"""
Initializes the processor by converting string configurations to torch objects.
This method sets up the `torch.device`, determines if transfers can be non-blocking, and validates the
`float_dtype` string, converting it to a `torch.dtype` object.
"""
self.tensor_device: torch.device = get_safe_torch_device(self.device)
# Update device string in case a specific GPU was selected (e.g., "cuda" -> "cuda:0")
self.device = self.tensor_device.type
self.non_blocking = "cuda" in str(self.device)
# Validate and convert float_dtype string to torch dtype
@@ -57,28 +77,33 @@ class DeviceProcessor(ProcessorStep):
raise ValueError(
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
)
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
else:
self._target_float_dtype = None
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""Process a tensor by moving to device and optionally converting float dtype.
"""
Moves a single tensor to the target device and casts its dtype.
If the tensor is already on a GPU and we're configured for a GPU, it preserves
that GPU placement (useful for multi-GPU training with Accelerate).
Otherwise, it moves to the configured device.
Handles multi-GPU scenarios by not moving a tensor if it's already on a different CUDA device than
the target, which is useful when using frameworks like Accelerate.
Args:
tensor: The input torch.Tensor.
Returns:
The processed tensor on the correct device and with the correct dtype.
"""
# Determine target device
if tensor.is_cuda and self._device.type == "cuda":
# Both tensor and target are on GPU - preserve tensor's GPU placement
if tensor.is_cuda and self.tensor_device.type == "cuda":
# Both tensor and target are on GPU - preserve tensor's GPU placement.
# This handles multi-GPU scenarios where Accelerate has already placed
# tensors on the correct GPU for each process
# tensors on the correct GPU for each process.
target_device = tensor.device
else:
# Either tensor is on CPU, or we're configured for CPU
# In both cases, use the configured device
target_device = self._device
# Either tensor is on CPU, or we're configured for CPU.
# In both cases, use the configured device.
target_device = self.tensor_device
# Only move if necessary
if tensor.device != target_device:
@@ -91,6 +116,18 @@ class DeviceProcessor(ProcessorStep):
return tensor
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Applies device and dtype conversion to all tensors in an environment transition.
It iterates through the transition, finds all `torch.Tensor` objects (including those nested in
dictionaries like `observation`), and processes them.
Args:
transition: The input `EnvTransition` object.
Returns:
A new `EnvTransition` object with all tensors moved to the target device and dtype.
"""
new_transition = transition.copy()
simple_tensor_keys = [
@@ -105,13 +142,13 @@ class DeviceProcessor(ProcessorStep):
TransitionKey.COMPLEMENTARY_DATA,
]
# Process simple tensors
# Process simple, top-level tensors
for key in simple_tensor_keys:
value = transition.get(key)
if isinstance(value, torch.Tensor):
new_transition[key] = self._process_tensor(value)
# Process dictionary-like tensors
# Process tensors nested within dictionaries
for key in dict_tensor_keys:
data_dict = transition.get(key)
if data_dict is not None:
@@ -124,5 +161,26 @@ class DeviceProcessor(ProcessorStep):
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
"""
Returns the serializable configuration of the processor.
Returns:
A dictionary containing the device and float_dtype settings.
"""
return {"device": self.device, "float_dtype": self.float_dtype}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Returns the input features unchanged.
Device and dtype transformations do not alter the fundamental definition of the features (e.g., shape).
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
return features
+40 -9
View File
@@ -1,4 +1,4 @@
#! /usr/bin/env python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
@@ -10,20 +10,35 @@
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import numpy as np
import torch
from lerobot.processor.converters import to_tensor
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from .converters import to_tensor
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register("torch2numpy_action_processor")
@dataclass
class Torch2NumpyActionProcessor(ActionProcessor):
"""Convert PyTorch tensor actions to NumPy arrays."""
class Torch2NumpyActionProcessorStep(ActionProcessorStep):
"""
Converts a PyTorch tensor action to a NumPy array.
This step is useful when the output of a policy (typically a torch.Tensor)
needs to be passed to an environment or component that expects a NumPy array.
Attributes:
squeeze_batch_dim: If True, removes the first dimension of the array
if it is of size 1. This is useful for converting a
batched action of size (1, D) to a single action of size (D,).
"""
squeeze_batch_dim: bool = True
@@ -36,8 +51,8 @@ class Torch2NumpyActionProcessor(ActionProcessor):
numpy_action = action.detach().cpu().numpy()
# Remove batch dimensions but preserve action dimensions
# Only squeeze if there's a batch dimension (first dim == 1)
# Remove batch dimensions but preserve action dimensions.
# Only squeeze if there's a batch dimension (first dim == 1).
if (
self.squeeze_batch_dim
and numpy_action.shape
@@ -48,11 +63,22 @@ class Torch2NumpyActionProcessor(ActionProcessor):
return numpy_action
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass
class Numpy2TorchActionProcessor(ActionProcessor):
"""Convert NumPy array action to PyTorch tensor."""
class Numpy2TorchActionProcessorStep(ActionProcessorStep):
"""
Converts a NumPy array action to a PyTorch tensor.
This step is useful for converting actions from environments or hardware,
which are often NumPy arrays, into PyTorch tensors that can be processed
by a policy or model.
"""
def action(self, action: np.ndarray) -> torch.Tensor:
if not isinstance(action, np.ndarray):
@@ -62,3 +88,8 @@ class Numpy2TorchActionProcessor(ActionProcessor):
)
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
return torch_action
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
+321 -43
View File
@@ -1,68 +1,203 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
from dataclasses import dataclass
from typing import Any
from typing import Any, Protocol, TypeVar, runtime_checkable
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PolicyFeature
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import ACTION
from lerobot.processor.pipeline import (
ComplementaryDataProcessor,
EnvTransition,
InfoProcessor,
ObservationProcessor,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
TruncatedProcessor,
)
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from .core import EnvTransition, TransitionKey
from .pipeline import (
ComplementaryDataProcessorStep,
InfoProcessorStep,
ObservationProcessorStep,
ProcessorStep,
ProcessorStepRegistry,
TruncatedProcessorStep,
)
GRIPPER_KEY = "gripper"
DISCRETE_PENALTY_KEY = "discrete_penalty"
TELEOP_ACTION_KEY = "teleop_action"
@runtime_checkable
class HasTeleopEvents(Protocol):
"""
Minimal protocol for objects that provide teleoperation events.
This protocol defines the `get_teleop_events()` method, allowing processor
steps to interact with teleoperators that support event-based controls
(like episode termination or success flagging) without needing to know the
teleoperator's specific class.
"""
def get_teleop_events(self) -> dict[str, Any]:
"""
Get extra control events from the teleoperator.
Returns:
A dictionary containing control events such as:
- `is_intervention`: bool - Whether the human is currently intervening.
- `terminate_episode`: bool - Whether to terminate the current episode.
- `success`: bool - Whether the episode was successful.
- `rerecord_episode`: bool - Whether to rerecord the episode.
"""
...
# Type variable constrained to Teleoperator subclasses that also implement events
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
def _check_teleop_with_events(teleop: Teleoperator) -> None:
"""
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
Args:
teleop: The teleoperator instance to check.
Raises:
TypeError: If the teleoperator does not have a `get_teleop_events` method.
"""
if not isinstance(teleop, HasTeleopEvents):
raise TypeError(
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
f"Compatible teleoperators: GamepadTeleop, KeyboardEndEffectorTeleop"
)
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@dataclass
class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
"""Add teleoperator action to transition complementary data."""
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
"""
Adds the raw action from a teleoperator to the transition's complementary data.
This is useful for human-in-the-loop scenarios where the human's input needs to
be available to downstream processors, for example, to override a policy's action
during an intervention.
Attributes:
teleop_device: The teleoperator instance to get the action from.
"""
teleop_device: Teleoperator
def complementary_data(self, complementary_data: dict) -> dict:
"""
Retrieves the teleoperator's action and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data dictionary.
Returns:
A new dictionary with the teleoperator action added under the
`teleop_action` key.
"""
new_complementary_data = dict(complementary_data)
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
return new_complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
class AddTeleopEventsAsInfo(InfoProcessor):
"""Add teleoperator control events to transition info."""
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
"""
Adds teleoperator control events (e.g., terminate, success) to the transition's info.
teleop_device: Teleoperator
This step extracts control events from teleoperators that support event-based
interaction, making these signals available to other parts of the system.
Attributes:
teleop_device: An instance of a teleoperator that implements the
`HasTeleopEvents` protocol.
"""
teleop_device: TeleopWithEvents
def __post_init__(self):
"""Validates that the provided teleoperator supports events after initialization."""
_check_teleop_with_events(self.teleop_device)
def info(self, info: dict) -> dict:
"""
Retrieves teleoperator events and updates the info dictionary.
Args:
info: The incoming info dictionary.
Returns:
A new dictionary including the teleoperator events.
"""
new_info = dict(info)
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
teleop_events = self.teleop_device.get_teleop_events()
new_info.update(teleop_events)
return new_info
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("image_crop_resize_processor")
@dataclass
class ImageCropResizeProcessor(ObservationProcessor):
"""Crop and resize image observations."""
class ImageCropResizeProcessorStep(ObservationProcessorStep):
"""
Crops and/or resizes image observations.
This step iterates through all image keys in an observation dictionary and applies
the specified transformations. It handles device placement, moving tensors to the
CPU if necessary for operations not supported on certain accelerators like MPS.
Attributes:
crop_params_dict: A dictionary mapping image keys to cropping parameters
(top, left, height, width).
resize_size: A tuple (height, width) to resize all images to.
"""
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
def observation(self, observation: dict) -> dict:
"""
Applies cropping and resizing to all images in the observation dictionary.
Args:
observation: The observation dictionary, potentially containing image tensors.
Returns:
A new observation dictionary with transformed images.
"""
if self.resize_size is None and not self.crop_params_dict:
return observation
@@ -90,29 +225,65 @@ class ImageCropResizeProcessor(ObservationProcessor):
return new_observation
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary with the crop parameters and resize dimensions.
"""
return {
"crop_params_dict": self.crop_params_dict,
"resize_size": self.resize_size,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Updates the image feature shapes in the policy features dictionary if resizing is applied.
Args:
features: The policy features dictionary.
Returns:
The updated policy features dictionary with new image shapes.
"""
if self.resize_size is None:
return features
for key in features:
for key in features[PipelineFeatureType.OBSERVATION]:
if "image" in key:
features[key] = PolicyFeature(type=features[key].type, shape=self.resize_size)
nb_channel = features[PipelineFeatureType.OBSERVATION][key].shape[0]
features[PipelineFeatureType.OBSERVATION][key] = PolicyFeature(
type=features[PipelineFeatureType.OBSERVATION][key].type,
shape=(nb_channel, *self.resize_size),
)
return features
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor(TruncatedProcessor):
"""Track episode steps and enforce time limits."""
class TimeLimitProcessorStep(TruncatedProcessorStep):
"""
Tracks episode steps and enforces a time limit by truncating the episode.
Attributes:
max_episode_steps: The maximum number of steps allowed per episode.
current_step: The current step count for the active episode.
"""
max_episode_steps: int
current_step: int = 0
def truncated(self, truncated):
def truncated(self, truncated: bool) -> bool:
"""
Increments the step counter and sets the truncated flag if the time limit is reached.
Args:
truncated: The incoming truncated flag.
Returns:
True if the episode step limit is reached, otherwise the incoming value.
"""
self.current_step += 1
if self.current_step >= self.max_episode_steps:
truncated = True
@@ -120,24 +291,54 @@ class TimeLimitProcessor(TruncatedProcessor):
return truncated
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the `max_episode_steps`.
"""
return {
"max_episode_steps": self.max_episode_steps,
}
def reset(self) -> None:
"""Resets the step counter, typically called at the start of a new episode."""
self.current_step = 0
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor(ComplementaryDataProcessor):
"""Apply penalty for inappropriate gripper usage."""
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
This step penalizes actions that attempt to close an already closed gripper or
open an already open one, based on position thresholds.
Attributes:
penalty: The negative reward value to apply.
max_gripper_pos: The maximum position value for the gripper, used for normalization.
"""
penalty: float = -0.01
max_gripper_pos: float = 30.0
def complementary_data(self, complementary_data):
"""Calculate gripper penalty and add to complementary data."""
def complementary_data(self, complementary_data: dict) -> dict:
"""
Calculates the gripper penalty and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data, which should contain
raw joint positions.
Returns:
A new complementary data dictionary with the `discrete_penalty` key added.
"""
action = self.transition.get(TransitionKey.ACTION)
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
@@ -164,25 +365,57 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
return new_complementary_data
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the penalty value and max gripper position.
"""
return {
"penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos,
}
def reset(self) -> None:
"""Reset the processor state."""
self.last_gripper_state = None
"""Resets the processor's internal state."""
pass
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor(ProcessorStep):
"""Handle human intervention actions and episode termination."""
class InterventionActionProcessorStep(ProcessorStep):
"""
Handles human intervention, overriding policy actions and managing episode termination.
When an intervention is detected (via teleoperator events in the `info` dict),
this step replaces the policy's action with the human's teleoperated action.
It also processes signals to terminate the episode or flag success.
Attributes:
use_gripper: Whether to include the gripper in the teleoperated action.
terminate_on_success: If True, automatically sets the `done` flag when a
`success` event is received.
"""
use_gripper: bool = False
terminate_on_success: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Processes the transition to handle interventions.
Args:
transition: The incoming environment transition.
Returns:
The modified transition, potentially with an overridden action, updated
reward, and termination status.
"""
action = transition.get(TransitionKey.ACTION)
if action is None:
return transition
@@ -238,16 +471,40 @@ class InterventionActionProcessor(ProcessorStep):
return new_transition
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the step's configuration attributes.
"""
return {
"use_gripper": self.use_gripper,
"terminate_on_success": self.terminate_on_success,
}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
class RewardClassifierProcessor(ProcessorStep):
"""Apply reward classification to image observations."""
class RewardClassifierProcessorStep(ProcessorStep):
"""
Applies a pretrained reward classifier to image observations to predict success.
This step uses a model to determine if the current state is successful, updating
the reward and potentially terminating the episode.
Attributes:
pretrained_path: Path to the pretrained reward classifier model.
device: The device to run the classifier on.
success_threshold: The probability threshold to consider a prediction as successful.
success_reward: The reward value to assign on success.
terminate_on_success: If True, terminates the episode upon successful classification.
reward_classifier: The loaded classifier model instance.
"""
pretrained_path: str | None = None
device: str = "cpu"
@@ -258,7 +515,7 @@ class RewardClassifierProcessor(ProcessorStep):
reward_classifier: Any = None
def __post_init__(self):
"""Initialize the reward classifier after dataclass initialization."""
"""Initializes the reward classifier model after the dataclass is created."""
if self.pretrained_path is not None:
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -267,15 +524,26 @@ class RewardClassifierProcessor(ProcessorStep):
self.reward_classifier.eval()
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
"""
Processes a transition, applying the reward classifier to its image observations.
Args:
transition: The incoming environment transition.
Returns:
The modified transition with an updated reward and done flag based on the
classifier's prediction.
"""
new_transition = transition.copy()
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is None or self.reward_classifier is None:
return transition
return new_transition
# Extract images from observation
images = {key: value for key, value in observation.items() if "image" in key}
if not images:
return transition
return new_transition
# Run reward classifier
start_time = time.perf_counter()
@@ -285,8 +553,8 @@ class RewardClassifierProcessor(ProcessorStep):
classifier_frequency = 1 / (time.perf_counter() - start_time)
# Calculate reward and termination
reward = transition.get(TransitionKey.REWARD, 0.0)
terminated = transition.get(TransitionKey.DONE, False)
reward = new_transition.get(TransitionKey.REWARD, 0.0)
terminated = new_transition.get(TransitionKey.DONE, False)
if math.isclose(success, 1, abs_tol=1e-2):
reward = self.success_reward
@@ -294,7 +562,6 @@ class RewardClassifierProcessor(ProcessorStep):
terminated = True
# Update transition
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = reward
new_transition[TransitionKey.DONE] = terminated
@@ -306,9 +573,20 @@ class RewardClassifierProcessor(ProcessorStep):
return new_transition
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the step's configuration attributes.
"""
return {
"device": self.device,
"success_threshold": self.success_threshold,
"success_reward": self.success_reward,
"terminate_on_success": self.terminate_on_success,
}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@@ -1,11 +1,28 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_STATE
from lerobot.processor.pipeline import (
ObservationProcessor,
ObservationProcessorStep,
ProcessorStepRegistry,
)
from lerobot.robots import Robot
@@ -13,19 +30,44 @@ from lerobot.robots import Robot
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor(ObservationProcessor):
"""Add joint velocity information to observations."""
class JointVelocityProcessorStep(ObservationProcessorStep):
"""
Calculates and appends joint velocity information to the observation state.
This step computes the velocity of each joint by calculating the finite
difference between the current and the last observed joint positions. The
resulting velocity vector is then concatenated to the original state vector.
Attributes:
dt: The time step (delta time) in seconds between observations, used for
calculating velocity.
last_joint_positions: Stores the joint positions from the previous step
to enable velocity calculation.
"""
dt: float = 0.1
last_joint_positions: torch.Tensor | None = None
def observation(self, observation: dict) -> dict:
"""
Computes joint velocities and adds them to the observation state.
Args:
observation: The input observation dictionary, expected to contain
an `observation.state` key with joint positions.
Returns:
A new observation dictionary with the `observation.state` tensor
extended to include joint velocities.
Raises:
ValueError: If `observation.state` is not found in the observation.
"""
# Get current joint positions (assuming they're in observation.state)
current_positions = observation.get("observation.state")
current_positions = observation.get(OBS_STATE)
if current_positions is None:
# TODO(steven): if we get here, then the transform_features method will not hold
return observation
raise ValueError(f"{OBS_STATE} is not in observation")
# Initialize last joint positions if not already set
if self.last_joint_positions is None:
@@ -42,46 +84,92 @@ class JointVelocityProcessor(ObservationProcessor):
# Create new observation dict
new_observation = dict(observation)
new_observation["observation.state"] = extended_state
new_observation[OBS_STATE] = extended_state
return new_observation
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the time step `dt`.
"""
return {
"dt": self.dt,
}
def reset(self) -> None:
"""Resets the internal state, clearing the last known joint positions."""
self.last_joint_positions = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if "observation.state" in features:
original_feature = features["observation.state"]
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Updates the `observation.state` feature to reflect the added velocities.
This method doubles the size of the first dimension of the `observation.state`
shape to account for the concatenation of position and velocity vectors.
Args:
features: The policy features dictionary.
Returns:
The updated policy features dictionary.
"""
if OBS_STATE in features[PipelineFeatureType.OBSERVATION]:
original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE]
# Double the shape to account for positions + velocities
new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature(
type=original_feature.type, shape=new_shape
)
return features
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessor(ObservationProcessor):
"""Add motor current information to observations."""
class MotorCurrentProcessorStep(ObservationProcessorStep):
"""
Reads motor currents from a robot and appends them to the observation state.
This step queries the robot's hardware interface to get the present current
for each motor and concatenates this information to the existing state vector.
Attributes:
robot: An instance of a `lerobot` Robot class that provides access to
the hardware bus.
"""
robot: Robot | None = None
def observation(self, observation: dict) -> dict:
"""
Fetches motor currents and adds them to the observation state.
Args:
observation: The input observation dictionary.
Returns:
A new observation dictionary with the `observation.state` tensor
extended to include motor currents.
Raises:
ValueError: If the `robot` attribute has not been set.
"""
# Get current values from robot state
if self.robot is None:
return observation
raise ValueError("Robot is not set")
present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
motor_currents = torch.tensor(
[present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
dtype=torch.float32,
).unsqueeze(0)
current_state = observation.get("observation.state")
current_state = observation.get(OBS_STATE)
if current_state is None:
return observation
@@ -89,15 +177,27 @@ class MotorCurrentProcessor(ObservationProcessor):
# Create new observation dict
new_observation = dict(observation)
new_observation["observation.state"] = extended_state
new_observation[OBS_STATE] = extended_state
return new_observation
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if "observation.state" in features and self.robot is not None:
from lerobot.configs.types import PolicyFeature
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Updates the `observation.state` feature to reflect the added motor currents.
original_feature = features["observation.state"]
This method increases the size of the first dimension of the `observation.state`
shape by the number of motors in the robot.
Args:
features: The policy features dictionary.
Returns:
The updated policy features dictionary.
"""
if OBS_STATE in features[PipelineFeatureType.OBSERVATION] and self.robot is not None:
original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE]
# Add motor current dimensions to the original state shape
num_motors = 0
if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
@@ -105,5 +205,7 @@ class MotorCurrentProcessor(ObservationProcessor):
if num_motors > 0:
new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature(
type=original_feature.type, shape=new_shape
)
return features
@@ -15,16 +15,22 @@
# limitations under the License.
"""
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
A generic script to migrate LeRobot policies with built-in normalization layers to the new
pipeline-based processor system.
This script:
1. Loads an existing pretrained policy model
2. Extracts normalization statistics from the model
3. Creates both preprocessor and postprocessor:
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
- Postprocessor: unnormalizes outputs (actions) for inference
4. Removes normalization layers from the model state_dict
5. Saves the new model and both processors
This script performs the following steps:
1. Loads a pretrained policy model and its configuration from a local path or the
Hugging Face Hub.
2. Scans the model's state dictionary to extract normalization statistics (e.g., mean,
std, min, max) for all features.
3. Creates two new processor pipelines:
- A preprocessor that normalizes inputs (observations) and outputs (actions).
- A postprocessor that unnormalizes outputs (actions) for inference.
4. Removes the original normalization layers from the model's state dictionary,
creating a "clean" model.
5. Saves the new clean model, the preprocessor, the postprocessor, and a generated
model card to a new directory.
6. Optionally pushes all the new artifacts to the Hugging Face Hub.
Usage:
python src/lerobot/processor/migrate_policy_normalization.py \
@@ -46,11 +52,12 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.batch_processor import ToBatchProcessor
from lerobot.processor.device_processor import DeviceProcessor
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from lerobot.processor.pipeline import RobotProcessor
from lerobot.processor.rename_processor import RenameProcessor
from .batch_processor import AddBatchDimensionProcessorStep
from .device_processor import DeviceProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from .pipeline import PolicyProcessorPipeline
from .rename_processor import RenameObservationsProcessorStep
# Policy type to class mapping
POLICY_CLASSES = {
@@ -67,7 +74,21 @@ POLICY_CLASSES = {
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Extract normalization statistics from model state_dict."""
"""
Scans a model's state_dict to find and extract normalization statistics.
This function identifies keys corresponding to normalization layers (e.g., those
for mean, std, min, max) based on a set of predefined patterns and organizes
them into a nested dictionary.
Args:
state_dict: The state dictionary of a pretrained policy model.
Returns:
A nested dictionary where outer keys are feature names (e.g.,
'observation.state') and inner keys are statistic types ('mean', 'std'),
mapping to their corresponding tensor values.
"""
stats = {}
# Define patterns to match and their prefixes to remove
@@ -111,7 +132,25 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
def detect_features_and_norm_modes(
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
"""Detect features and normalization modes from config and stats."""
"""
Infers policy features and normalization modes from the model config and stats.
This function first attempts to find feature definitions and normalization
mappings directly from the policy's configuration file. If this information is
not present, it infers it from the extracted normalization statistics, using
tensor shapes to determine feature shapes and the presence of specific stat
keys (e.g., 'mean'/'std' vs 'min'/'max') to determine the normalization mode.
It applies sensible defaults if inference is not possible.
Args:
config: The policy's configuration dictionary from `config.json`.
stats: The normalization statistics extracted from the model's state_dict.
Returns:
A tuple containing:
- A dictionary mapping feature names to `PolicyFeature` objects.
- A dictionary mapping `FeatureType` enums to `NormalizationMode` enums.
"""
features = {}
norm_modes = {}
@@ -203,7 +242,19 @@ def detect_features_and_norm_modes(
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Remove normalization layers from state_dict."""
"""
Creates a new state_dict with all normalization-related layers removed.
This function filters the original state dictionary, excluding any keys that
match a set of predefined patterns associated with normalization modules.
Args:
state_dict: The original model state dictionary.
Returns:
A new state dictionary containing only the core model weights, without
any normalization parameters.
"""
new_state_dict = {}
# Patterns to remove
@@ -227,7 +278,16 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert features from old format to PolicyFeature objects."""
"""
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
Args:
features_dict: The feature dictionary in the old format, where values are
simple dictionaries (e.g., `{"shape": [7]}`).
Returns:
A dictionary mapping feature names to `PolicyFeature` dataclass objects.
"""
converted_features = {}
for key, feature_dict in features_dict.items():
@@ -253,8 +313,18 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
def load_model_from_hub(
repo_id: str, revision: str = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""Load model state_dict and config from hub."""
# Download files
"""
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
Args:
repo_id: The repository ID on the Hub (e.g., 'lerobot/aloha').
revision: The specific git revision (branch, tag, or commit hash) to use.
Returns:
A tuple containing the model's state dictionary, the policy configuration,
and the training configuration.
"""
# Download files.
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
@@ -403,8 +473,8 @@ def main():
# Now create preprocessor and postprocessor with cleaned_config available
print("Creating preprocessor and postprocessor...")
# The pattern from existing processor factories:
# - Preprocessor has two NormalizerProcessors: one for input_features, one for output_features
# - Postprocessor has one UnnormalizerProcessor for output_features only
# - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
# - Postprocessor has one UnnormalizerProcessorStep for output_features only
# Get features from cleaned_config (now they're PolicyFeature objects)
input_features = cleaned_config.get("input_features", {})
@@ -412,23 +482,23 @@ def main():
# Create preprocessor with two normalizers (following the pattern from processor factories)
preprocessor_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
RenameObservationsProcessorStep(rename_map={}),
NormalizerProcessorStep(
features={**input_features, **output_features},
norm_map=norm_map,
stats=stats,
),
ToBatchProcessor(),
DeviceProcessor(device=policy_config.device),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=policy_config.device),
]
preprocessor = RobotProcessor(steps=preprocessor_steps, name="robot_preprocessor")
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")
# Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
]
postprocessor = RobotProcessor(steps=postprocessor_steps, name="robot_postprocessor")
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")
# Determine hub repo ID if pushing to hub
if args.push_to_hub:
+226 -54
View File
@@ -1,3 +1,20 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from copy import deepcopy
@@ -7,16 +24,12 @@ from typing import Any
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.converters import to_tensor
from lerobot.processor.pipeline import (
EnvTransition,
ProcessorStep,
ProcessorStepRegistry,
RobotProcessor,
TransitionKey,
)
from .converters import from_tensor_to_numpy, to_tensor
from .core import EnvTransition, TransitionKey
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
@dataclass
@@ -24,22 +37,48 @@ class _NormalizationMixin:
"""
A mixin class providing core functionality for normalization and unnormalization.
This class manages normalization statistics, their conversion to tensors, device placement,
and the application of normalization transformations. It is designed to be inherited by
concrete ProcessorStep implementations.
This class manages normalization statistics (`stats`), converts them to tensors for
efficient computation, handles device placement, and implements the logic for
applying normalization transformations (mean/std and min/max). It is designed to
be inherited by concrete `ProcessorStep` implementations and should not be used
directly.
Attributes:
features: A dictionary mapping feature names to `PolicyFeature` objects, defining
the data structure to be processed.
norm_map: A dictionary mapping `FeatureType` to `NormalizationMode`, specifying
which normalization method to use for each type of feature.
stats: A dictionary containing the normalization statistics (e.g., mean, std,
min, max) for each feature.
device: The PyTorch device on which to store and perform tensor operations.
eps: A small epsilon value to prevent division by zero in normalization
calculations.
normalize_observation_keys: An optional set of keys to selectively apply
normalization to specific observation features.
_tensor_stats: An internal dictionary holding the normalization statistics as
PyTorch tensors.
"""
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
stats: dict[str, dict[str, Any]] | None = None
device: torch.device | str | None = None
dtype: torch.dtype | None = None
eps: float = 1e-8
normalize_observation_keys: set[str] | None = None
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
def __post_init__(self):
# Robust JSON deserialization handling (guard empty maps)
"""
Initializes the mixin after dataclass construction.
This method handles the robust deserialization of `features` and `norm_map`
from JSON-compatible formats (where enums become strings and tuples become
lists) and converts the provided `stats` dictionary into a dictionary of
tensors (`_tensor_stats`) on the specified device.
"""
# Robust JSON deserialization handling (guard empty maps).
if self.features:
first_val = next(iter(self.features.values()))
if isinstance(first_val, dict):
@@ -60,15 +99,40 @@ class _NormalizationMixin:
# Convert stats to tensors and move to the target device once during initialization.
self.stats = self.stats or {}
self._tensor_stats = to_tensor(self.stats, device=self.device)
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
def to(self, device: torch.device | str) -> _NormalizationMixin:
"""Moves the processor's normalization stats to the specified device and returns self."""
self.device = device
self._tensor_stats = to_tensor(self.stats, device=self.device)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
) -> _NormalizationMixin:
"""
Moves the processor's normalization stats to the specified device.
Args:
device: The target PyTorch device.
Returns:
The instance of the class, allowing for method chaining.
"""
if device is not None:
self.device = device
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
return self
def state_dict(self) -> dict[str, Tensor]:
"""
Returns the normalization statistics as a flat state dictionary.
All tensors are moved to the CPU before being returned, which is standard practice
for saving state dictionaries.
Returns:
A flat dictionary mapping from `'feature_name.stat_name'` to the
corresponding statistics tensor on the CPU.
"""
flat: dict[str, Tensor] = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
@@ -76,6 +140,15 @@ class _NormalizationMixin:
return flat
def load_state_dict(self, state: dict[str, Tensor]) -> None:
"""
Loads normalization statistics from a state dictionary.
The loaded tensors are moved to the processor's configured device.
Args:
state: A flat state dictionary with keys in the format
`'feature_name.stat_name'`.
"""
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
@@ -84,7 +157,26 @@ class _NormalizationMixin:
dtype=torch.float32, device=self.device
)
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
self.stats = {}
for key, tensor_dict in self._tensor_stats.items():
self.stats[key] = {}
for stat_name, tensor in tensor_dict.items():
# Convert tensor back to python/numpy format
self.stats[key][stat_name] = from_tensor_to_numpy(tensor)
def get_config(self) -> dict[str, Any]:
"""
Returns a serializable dictionary of the processor's configuration.
This method is used when saving the processor to disk, ensuring that its
configuration can be reconstructed later.
Returns:
A JSON-serializable dictionary containing the configuration.
"""
config = {
"eps": self.eps,
"features": {
@@ -97,24 +189,63 @@ class _NormalizationMixin:
return config
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
"""
Applies (un)normalization to all relevant features in an observation dictionary.
Args:
observation: The observation dictionary to process.
inverse: If `True`, applies unnormalization; otherwise, applies normalization.
Returns:
A new observation dictionary with the transformed tensor values.
"""
new_observation = dict(observation)
for key, feature in self.features.items():
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
continue
if feature.type != FeatureType.ACTION and key in new_observation:
tensor = torch.as_tensor(new_observation[key], dtype=torch.float32)
# Convert to tensor but preserve original dtype for adaptation logic
tensor = torch.as_tensor(new_observation[key])
new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
return new_observation
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
tensor = torch.as_tensor(action, dtype=torch.float32)
# Convert to tensor but preserve original dtype for adaptation logic
"""
Applies (un)normalization to an action tensor.
Args:
action: The action tensor to process.
inverse: If `True`, applies unnormalization; otherwise, applies normalization.
Returns:
The transformed action tensor.
"""
tensor = torch.as_tensor(action)
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
return processed_action
def _apply_transform(
self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
) -> Tensor:
"""Core logic to apply normalization or unnormalization."""
"""
Core logic to apply a normalization or unnormalization transformation to a tensor.
This method selects the appropriate normalization mode (e.g., mean/std, min/max)
based on the feature type and applies the corresponding mathematical operation.
Args:
tensor: The input tensor to transform.
key: The feature key corresponding to the tensor.
feature_type: The `FeatureType` of the tensor.
inverse: If `True`, applies the inverse transformation (unnormalization).
Returns:
The transformed tensor.
Raises:
ValueError: If an unsupported normalization mode is encountered.
"""
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
return tensor
@@ -122,19 +253,13 @@ class _NormalizationMixin:
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# Ensure input tensor is on the same device as the stats.
if self.device and tensor.device != self.device:
tensor = tensor.to(self.device)
# For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor
if self._tensor_stats and key in self._tensor_stats:
first_stat = next(iter(self._tensor_stats[key].values()))
if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
self.to(device=tensor.device, dtype=tensor.dtype)
# For Accelerate compatibility: move stats to match input tensor device
input_device = tensor.device
stats = self._tensor_stats[key]
tensor = tensor.to(dtype=torch.float32)
# Move stats to input device if needed
stats_device = next(iter(stats.values())).device
if stats_device != input_device:
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
@@ -151,7 +276,7 @@ class _NormalizationMixin:
# to prevent division by zero. This consistently maps an input equal to
# min_val to -1, ensuring a stable transformation.
denom = torch.where(
denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
)
if inverse:
# Map from [-1, 1] back to [min, max]
@@ -165,13 +290,13 @@ class _NormalizationMixin:
@dataclass
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
"""
A processor that applies normalization to observations and actions in a transition.
A processor step that applies normalization to observations and actions in a transition.
This class directly implements the normalization logic for both observation and action
components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
initialization.
This class uses the logic from `_NormalizationMixin` to perform forward normalization
(e.g., scaling data to have zero mean and unit variance, or to the range [-1, 1]).
It is typically used in the pre-processing pipeline before feeding data to a policy.
"""
@classmethod
@@ -184,7 +309,21 @@ class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
normalize_observation_keys: set[str] | None = None,
eps: float = 1e-8,
device: torch.device | str | None = None,
) -> NormalizerProcessor:
) -> NormalizerProcessorStep:
"""
Creates a `NormalizerProcessorStep` instance using statistics from a `LeRobotDataset`.
Args:
dataset: The dataset from which to extract normalization statistics.
features: The feature definition for the processor.
norm_map: The mapping from feature types to normalization modes.
normalize_observation_keys: An optional set of observation keys to normalize.
eps: A small epsilon value for numerical stability.
device: The target device for the processor.
Returns:
A new instance of `NormalizerProcessorStep`.
"""
return cls(
features=features,
norm_map=norm_map,
@@ -211,16 +350,22 @@ class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor")
class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
"""
A processor that applies unnormalization (the inverse of normalization) to
observations and actions in a transition.
A processor step that applies unnormalization to observations and actions.
This is typically used to transform actions from a normalized policy output back into
the original scale for execution in an environment.
This class inverts the normalization process, scaling data back to its original
range. It is typically used in the post-processing pipeline to convert a policy's
normalized action output into a format that can be executed by a robot or
environment.
"""
@classmethod
@@ -231,7 +376,19 @@ class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
norm_map: dict[FeatureType, NormalizationMode],
*,
device: torch.device | str | None = None,
) -> UnnormalizerProcessor:
) -> UnnormalizerProcessorStep:
"""
Creates an `UnnormalizerProcessorStep` using statistics from a `LeRobotDataset`.
Args:
dataset: The dataset from which to extract normalization statistics.
features: The feature definition for the processor.
norm_map: The mapping from feature types to normalization modes.
device: The target device for the processor.
Returns:
A new instance of `UnnormalizerProcessorStep`.
"""
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -249,20 +406,35 @@ class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
"""
Replaces normalization statistics in a RobotProcessor pipeline.
This function creates a deep copy of the provided `RobotProcessor` and updates the
statistics of any `NormalizerProcessor` or `UnnormalizerProcessor` steps within it.
It's useful for adapting a trained policy to a new environment or dataset with
different data distributions.
def hotswap_stats(
policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]]
) -> PolicyProcessorPipeline:
"""
rp = deepcopy(robot_processor)
Replaces normalization statistics in an existing `PolicyProcessorPipeline` instance.
This function creates a deep copy of the provided pipeline and updates the
statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` it
contains. This is useful for adapting a trained policy to a new environment or
dataset with different data distributions without having to reconstruct the entire
pipeline.
Args:
policy_processor: The policy processor pipeline to modify.
stats: The new dictionary of normalization statistics to apply.
Returns:
A new `PolicyProcessorPipeline` instance with the updated statistics.
"""
rp = deepcopy(policy_processor)
for step in rp.steps:
if isinstance(step, _NormalizationMixin):
step.stats = stats
# Re-initialize tensor_stats on the correct device.
step._tensor_stats = to_tensor(stats, device=step.device)
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype)
return rp
+99 -49
View File
@@ -20,32 +20,54 @@ import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessor(ObservationProcessor):
class VanillaObservationProcessorStep(ObservationProcessorStep):
"""
Processes environment observations into the LeRobot format by handling both images and states.
Processes standard Gymnasium observations into the LeRobot format.
Image processing:
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
- Adds a batch dimension if missing
- Supports single images and image dictionaries
This step handles both image and state data from a typical observation dictionary,
preparing it for use in a LeRobot policy.
State processing:
- Maps 'environment_state' to observation.environment_state
- Maps 'agent_pos' to observation.state
- Converts numpy arrays to tensors
- Adds a batch dimension if missing
**Image Processing:**
- Converts channel-last (H, W, C), `uint8` images to channel-first (C, H, W),
`float32` tensors.
- Normalizes pixel values from the [0, 255] range to [0, 1].
- Adds a batch dimension if one is not already present.
- Recognizes a single image under the key `"pixels"` and maps it to
`"observation.image"`.
- Recognizes a dictionary of images under the key `"pixels"` and maps them
to `"observation.images.{camera_name}"`.
**State Processing:**
- Maps the `"environment_state"` key to `"observation.environment_state"`.
- Maps the `"agent_pos"` key to `"observation.state"`.
- Converts NumPy arrays to PyTorch tensors.
- Adds a batch dimension if one is not already present.
"""
def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array."""
"""
Processes a single NumPy image array into a channel-first, normalized tensor.
Args:
img: A NumPy array representing the image, expected to be in channel-last
(H, W, C) format with a `uint8` dtype.
Returns:
A `float32` PyTorch tensor in channel-first (B, C, H, W) format, with
pixel values normalized to the [0, 1] range.
Raises:
ValueError: If the input image does not appear to be in channel-last
format or is not of `uint8` dtype.
"""
# Convert to tensor
img_tensor = torch.from_numpy(img)
@@ -106,18 +128,32 @@ class VanillaObservationProcessor(ObservationProcessor):
def observation(self, observation):
return self._process_observation(observation)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms feature keys to a standardized contract.
This method handles several renaming patterns:
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
- environment_state -> OBS_ENV_STATE,
- agent_pos -> OBS_STATE,
- observation.environment_state -> OBS_ENV_STATE,
- observation.agent_pos -> OBS_STATE
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the Gym standard to the LeRobot standard.
This method standardizes the feature dictionary by renaming keys according
to LeRobot's conventions, ensuring that policies can be constructed correctly.
It handles various raw key formats, including those with an "observation." prefix.
**Renaming Rules:**
- `pixels` or `observation.pixels` -> `observation.image`
- `pixels.{cam}` or `observation.pixels.{cam}` -> `observation.images.{cam}`
- `environment_state` or `observation.environment_state` -> `observation.environment_state`
- `agent_pos` or `observation.agent_pos` -> `observation.state`
Args:
features: The policy features dictionary with Gym-style keys.
Returns:
The policy features dictionary with standardized LeRobot keys.
"""
# Build a new features mapping keyed by the same FeatureType buckets
# We assume callers already placed features in the correct FeatureType.
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()}
exact_pairs = {
"pixels": OBS_IMAGE,
"environment_state": OBS_ENV_STATE,
@@ -128,29 +164,43 @@ class VanillaObservationProcessor(ObservationProcessor):
"pixels.": f"{OBS_IMAGES}.",
}
for key in list(features.keys()):
matched_prefix = False
for old_prefix, new_prefix in prefix_pairs.items():
prefixed_old = f"observation.{old_prefix}"
if key.startswith(prefixed_old):
suffix = key[len(prefixed_old) :]
features[f"{new_prefix}{suffix}"] = features.pop(key)
matched_prefix = True
break
# Iterate over all incoming feature buckets and normalize/move each entry
for src_ft, bucket in features.items():
for key, feat in list(bucket.items()):
handled = False
if key.startswith(old_prefix):
suffix = key[len(old_prefix) :]
features[f"{new_prefix}{suffix}"] = features.pop(key)
matched_prefix = True
break
if matched_prefix:
continue
for old, new in exact_pairs.items():
if key == old or key == f"observation.{old}":
if key in features:
features[new] = features.pop(key)
# Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1)
for old_prefix, new_prefix in prefix_pairs.items():
prefixed_old = f"observation.{old_prefix}"
if key.startswith(prefixed_old):
suffix = key[len(prefixed_old) :]
new_key = f"{new_prefix}{suffix}"
new_features[src_ft][new_key] = feat
handled = True
break
return features
if key.startswith(old_prefix):
suffix = key[len(old_prefix) :]
new_key = f"{new_prefix}{suffix}"
new_features[src_ft][new_key] = feat
handled = True
break
if handled:
continue
# Exact-name rules (pixels, environment_state, agent_pos)
for old, new in exact_pairs.items():
if key == old or key == f"observation.{old}":
new_key = new
new_features[src_ft][new_key] = feat
handled = True
break
if handled:
continue
# Default: keep key in the same source FeatureType bucket
new_features[src_ft][key] = feat
return new_features
+132 -215
View File
@@ -22,48 +22,22 @@ from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Generic, TypedDict, TypeVar, cast
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_file, save_file
from lerobot.configs.types import PolicyFeature
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from .converters import batch_to_transition, create_transition, transition_to_batch
from .core import EnvTransition, TransitionKey
# Type variable for generic processor output type
TOutput = TypeVar("TOutput")
class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components."""
# TODO(Steven): Use consts
OBSERVATION = "observation"
ACTION = "action"
REWARD = "reward"
DONE = "done"
TRUNCATED = "truncated"
INFO = "info"
COMPLEMENTARY_DATA = "complementary_data"
EnvTransition = TypedDict(
"EnvTransition",
{
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
TransitionKey.ACTION.value: Any | torch.Tensor | None,
TransitionKey.REWARD.value: float | torch.Tensor | None,
TransitionKey.DONE.value: bool | torch.Tensor | None,
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
TransitionKey.INFO.value: dict[str, Any] | None,
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
},
)
class ProcessorStepRegistry:
"""Registry for processor steps that enables saving/loading by name instead of module path."""
@@ -142,7 +116,7 @@ class ProcessorStep(ABC):
A step is any callable accepting a full `EnvTransition` dict and
returning a (possibly modified) dict of the same structure. Implementers
are encouragedbut not requiredto expose the optional helper methods
listed below. When present, these hooks let `RobotProcessor`
listed below. When present, these hooks let `DataProcessorPipeline`
automatically serialise the step's configuration and learnable state using
a safe-to-share JSON + SafeTensors format.
@@ -194,107 +168,22 @@ class ProcessorStep(ABC):
def reset(self) -> None:
return None
# TODO(Steven): Consider making this abstract so it is more explicit
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
@abstractmethod
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
"""Convert a *batch* dict coming from Learobot replay/dataset code into an
``EnvTransition`` dictionary.
The function maps well known keys to the EnvTransition structure. Missing keys are
filled with sane defaults (``None`` or ``0.0``/``False``).
Keys recognised (case-sensitive):
* "observation.*" (keys starting with "observation." are grouped into observation dict)
* "action"
* "next.reward"
* "next.done"
* "next.truncated"
* "info"
Additional keys are ignored so that existing dataloaders can carry extra
metadata without breaking the processor.
"""
# Validate input type
if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
# Extract observation keys
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
observation = observation_keys if observation_keys else None
# Extract padding, task, index, and task_index keys for complementary data
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
complementary_data = (
{**pad_keys, **task_key, **index_key, **task_index_key}
if pad_keys or task_key or index_key or task_index_key
else {}
)
transition: EnvTransition = {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: batch.get("action"),
TransitionKey.REWARD: batch.get("next.reward", 0.0),
TransitionKey.DONE: batch.get("next.done", False),
TransitionKey.TRUNCATED: batch.get("next.truncated", False),
TransitionKey.INFO: batch.get("info", {}),
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
return transition
def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401
"""Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with
the canonical field names used throughout *LeRobot*.
"""
batch = {
"action": transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}),
}
# Add padding, task, index, and task_index data from complementary_data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
batch.update(pad_data)
if "task" in complementary_data:
batch["task"] = complementary_data["task"]
if "index" in complementary_data:
batch["index"] = complementary_data["index"]
if "task_index" in complementary_data:
batch["task_index"] = complementary_data["task_index"]
# Handle observation - flatten dict to observation.* keys if it's a dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
batch.update(observation)
return batch
class ProcessorKwargs(TypedDict, total=False):
"""Keyword arguments for RobotProcessor constructor."""
"""Keyword arguments for DataProcessorPipeline constructor."""
to_transition: Callable[[dict[str, Any]], EnvTransition] | None
to_output: Callable[[EnvTransition], Any] | None
@dataclass
class RobotProcessor(ModelHubMixin, Generic[TOutput]):
class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]):
"""
Composable, debuggable post-processing processor for robot transitions.
@@ -308,7 +197,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
Args:
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
name: Human-readable identifier that is persisted inside the JSON config.
Defaults to "RobotProcessor".
Defaults to "DataProcessorPipeline".
to_transition: Function to convert batch dict to EnvTransition dict.
Defaults to _default_batch_to_transition.
to_output: Function to convert EnvTransition dict to the desired output format of type TOutput.
@@ -322,18 +211,20 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
Type Safety Examples:
```python
# Default behavior - returns batch dict
processor: RobotProcessor[dict[str, Any]] = RobotProcessor(steps=[some_step1, some_step2])
processor: DataProcessorPipeline[dict[str, Any]] = DataProcessorPipeline(
steps=[some_step1, some_step2]
)
result: dict[str, Any] = processor(batch_data) # Type checker knows this is a dict
# For EnvTransition output, explicitly specify identity function
transition_processor: RobotProcessor[EnvTransition] = RobotProcessor(
transition_processor: DataProcessorPipeline[EnvTransition] = DataProcessorPipeline(
steps=[some_step1, some_step2],
to_output=lambda x: x, # Identity function
)
result: EnvTransition = transition_processor(batch_data) # Type checker knows this is EnvTransition
# For custom output types
processor: RobotProcessor[str] = RobotProcessor(
processor: DataProcessorPipeline[str] = DataProcessorPipeline(
steps=[custom_step], to_output=lambda t: f"Processed {len(t)} keys"
)
result: str = processor(batch_data) # Type checker knows this is str
@@ -355,17 +246,15 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
"""
steps: Sequence[ProcessorStep] = field(default_factory=list)
name: str = "RobotProcessor"
name: str = "DataProcessorPipeline"
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
default_factory=lambda: _default_batch_to_transition, repr=False
)
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(default=batch_to_transition, repr=False)
to_output: Callable[[EnvTransition], TOutput] = field(
# Cast is necessary here: Working around Python type-checker limitation.
# _default_transition_to_batch returns dict[str, Any], but we need it to be TOutput
# for the generic to work. When no explicit type is given, TOutput defaults to dict[str, Any],
# making this cast safe.
default_factory=lambda: cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
default_factory=lambda: cast(Callable[[EnvTransition], TOutput], transition_to_batch),
repr=False,
)
@@ -390,6 +279,12 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
# Always convert input through to_transition
transition = self.to_transition(data)
transformed_transition = self._forward(transition)
# Always use to_output for consistent typing
return self.to_output(transformed_transition)
def _forward(self, transition: EnvTransition) -> EnvTransition:
# Process through all steps
for idx, processor_step in enumerate(self.steps):
# Apply before hooks
@@ -402,9 +297,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
# Apply after hooks
for hook in self.after_step_hooks:
hook(idx, transition)
# Always use to_output for consistent typing
return self.to_output(transition)
return transition
def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
"""Yield the intermediate results after each processor step.
@@ -529,7 +422,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
to_transition: Callable[[dict[str, Any]], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
**kwargs,
) -> RobotProcessor[TOutput]:
) -> DataProcessorPipeline[TOutput]:
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
Args:
@@ -537,7 +430,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
(e.g., "username/processor-name").
config_filename: Optional specific config filename to load. If not provided, will:
- For local paths: look for any .json file in the directory (error if multiple found)
- For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json")
- For HF Hub: REQUIRED - you must specify the exact config filename
overrides: Optional dictionary mapping step names to configuration overrides.
Keys must match exact step class names (for unregistered steps) or registry names
(for registered steps). Values are dictionaries containing parameter overrides
@@ -550,7 +443,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
Use identity function (lambda x: x) for EnvTransition output.
Returns:
A RobotProcessor[TOutput] instance loaded from the saved configuration.
A DataProcessorPipeline[TOutput] instance loaded from the saved configuration.
Raises:
ImportError: If a processor step class cannot be loaded or imported.
@@ -560,13 +453,13 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
Examples:
Basic loading:
```python
processor = RobotProcessor.from_pretrained("path/to/processor")
processor = DataProcessorPipeline.from_pretrained("path/to/processor")
```
Loading specific config file:
Loading from HF Hub (config_filename required):
```python
processor = RobotProcessor.from_pretrained(
"username/multi-processor-repo", config_filename="preprocessor.json"
processor = DataProcessorPipeline.from_pretrained(
"username/processor-repo", config_filename="processor.json"
)
```
@@ -575,14 +468,14 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
import gym
env = gym.make("CartPole-v1")
processor = RobotProcessor.from_pretrained(
processor = DataProcessorPipeline.from_pretrained(
"username/cartpole-processor", overrides={"ActionRepeatStep": {"env": env}}
)
```
Multiple overrides:
```python
processor = RobotProcessor.from_pretrained(
processor = DataProcessorPipeline.from_pretrained(
"path/to/processor",
overrides={
"CustomStep": {"param1": "new_value"},
@@ -594,7 +487,19 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
# Use the local variable name 'source' for clarity
source = str(pretrained_model_name_or_path)
if Path(source).is_dir():
# Check if it's a local path (either exists or looks like a filesystem path)
# Hub repositories are typically in the format "username/repo-name" (exactly one slash)
# Local paths are absolute paths, relative paths, or have more complex path structure
is_local_path = (
Path(source).is_dir()
or Path(source).is_absolute()
or source.startswith("./")
or source.startswith("../")
or source.count("/") > 1 # More than one slash suggests local path, not Hub repo
or "\\" in source # Windows-style paths are definitely local
)
if is_local_path:
# Local path - use it directly
base_path = Path(source)
@@ -613,57 +518,26 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
with open(base_path / config_filename) as file_pointer:
loaded_config: dict[str, Any] = json.load(file_pointer)
else:
# Hugging Face Hub - download all required files
# Hugging Face Hub - download specific config file
if config_filename is None:
# Try common config names
common_names = [
"robot_processor.json",
"robot_preprocessor.json",
"robot_postprocessor.json",
]
config_path = None
for name in common_names:
try:
config_path = hf_hub_download(
source,
name,
repo_type="model",
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
config_filename = name
break
except (FileNotFoundError, OSError, HfHubHTTPError):
# FileNotFoundError: local file issues
# OSError: network/system errors
# HfHubHTTPError: file not found on Hub (404) or other HTTP errors
continue
if config_path is None:
raise FileNotFoundError(
f"No processor configuration file found in {source}. "
f"Tried: {common_names}. Please specify the config_filename parameter."
)
else:
# Download specific config file
config_path = hf_hub_download(
source,
config_filename,
repo_type="model",
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
raise ValueError(
f"For Hugging Face Hub repositories ({source}), you must specify the config_filename parameter. "
f"Example: DataProcessorPipeline.from_pretrained('{source}', config_filename='processor.json')"
)
config_path = hf_hub_download(
source,
config_filename,
repo_type="model",
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
with open(config_path) as file_pointer:
loaded_config = json.load(file_pointer)
@@ -766,25 +640,25 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
return cls(
steps=steps,
name=loaded_config.get("name", "RobotProcessor"),
to_transition=to_transition or _default_batch_to_transition,
name=loaded_config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or batch_to_transition,
# Cast is necessary here: Same type-checker limitation as above.
# When to_output is None, we use the default which returns dict[str, Any].
# The cast ensures type consistency with the generic TOutput parameter.
to_output=to_output or cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
def __len__(self) -> int:
"""Return the number of steps in the processor."""
return len(self.steps)
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor[TOutput]:
def __getitem__(self, idx: int | slice) -> ProcessorStep | DataProcessorPipeline[TOutput]:
"""Indexing helper exposing underlying steps.
* ``int`` returns the idx-th ProcessorStep.
* ``slice`` returns a new RobotProcessor with the sliced steps.
* ``slice`` returns a new DataProcessorPipeline with the sliced steps.
"""
if isinstance(idx, slice):
return RobotProcessor(
return DataProcessorPipeline(
steps=self.steps[idx],
name=self.name,
to_transition=self.to_transition,
@@ -855,30 +729,68 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
parts = [f"name='{self.name}'", steps_repr]
return f"RobotProcessor({', '.join(parts)})"
return f"DataProcessorPipeline({', '.join(parts)})"
def __post_init__(self):
for i, step in enumerate(self.steps):
if not callable(step):
# TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated
raise TypeError(
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
)
if not isinstance(step, ProcessorStep):
raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep")
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, initial_features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Apply ALL steps in order. Only if a step has a features method, it will be called.
We aggregate the dataset features of all steps.
"""
features: dict[str, PolicyFeature] = deepcopy(initial_features)
features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = deepcopy(initial_features)
for _, step in enumerate(self.steps):
out = step.transform_features(features)
features = out
return features
def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]:
transition: EnvTransition = create_transition(observation=observation)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.OBSERVATION]
class ObservationProcessor(ProcessorStep, ABC):
def process_action(self, action: Any | torch.Tensor) -> Any | torch.Tensor:
transition: EnvTransition = create_transition(action=action)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.ACTION]
def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor:
transition: EnvTransition = create_transition(reward=reward)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.REWARD]
def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor:
transition: EnvTransition = create_transition(done=done)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.DONE]
def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor:
transition: EnvTransition = create_transition(truncated=truncated)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.TRUNCATED]
def process_info(self, info: dict[str, Any]) -> dict[str, Any]:
transition: EnvTransition = create_transition(info=info)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.INFO]
def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]:
transition: EnvTransition = create_transition(complementary_data=complementary_data)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.COMPLEMENTARY_DATA]
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TOutput]
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TOutput]
class ObservationProcessorStep(ProcessorStep, ABC):
"""Base class for processors that modify only the observation component of a transition.
Subclasses should override the `observation` method to implement custom observation processing.
@@ -924,7 +836,7 @@ class ObservationProcessor(ProcessorStep, ABC):
return new_transition
class ActionProcessor(ProcessorStep, ABC):
class ActionProcessorStep(ProcessorStep, ABC):
"""Base class for processors that modify only the action component of a transition.
Subclasses should override the `action` method to implement custom action processing.
@@ -971,7 +883,7 @@ class ActionProcessor(ProcessorStep, ABC):
return new_transition
class RewardProcessor(ProcessorStep, ABC):
class RewardProcessorStep(ProcessorStep, ABC):
"""Base class for processors that modify only the reward component of a transition.
Subclasses should override the `reward` method to implement custom reward processing.
@@ -1017,7 +929,7 @@ class RewardProcessor(ProcessorStep, ABC):
return new_transition
class DoneProcessor(ProcessorStep, ABC):
class DoneProcessorStep(ProcessorStep, ABC):
"""Base class for processors that modify only the done flag of a transition.
Subclasses should override the `done` method to implement custom done flag processing.
@@ -1068,7 +980,7 @@ class DoneProcessor(ProcessorStep, ABC):
return new_transition
class TruncatedProcessor(ProcessorStep, ABC):
class TruncatedProcessorStep(ProcessorStep, ABC):
"""Base class for processors that modify only the truncated flag of a transition.
Subclasses should override the `truncated` method to implement custom truncated flag processing.
@@ -1115,7 +1027,7 @@ class TruncatedProcessor(ProcessorStep, ABC):
return new_transition
class InfoProcessor(ProcessorStep, ABC):
class InfoProcessorStep(ProcessorStep, ABC):
"""Base class for processors that modify only the info dictionary of a transition.
Subclasses should override the `info` method to implement custom info processing.
@@ -1167,7 +1079,7 @@ class InfoProcessor(ProcessorStep, ABC):
return new_transition
class ComplementaryDataProcessor(ProcessorStep, ABC):
class ComplementaryDataProcessorStep(ProcessorStep, ABC):
"""Base class for processors that modify only the complementary data of a transition.
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
@@ -1200,8 +1112,13 @@ class ComplementaryDataProcessor(ProcessorStep, ABC):
return new_transition
class IdentityProcessor(ProcessorStep):
class IdentityProcessorStep(ProcessorStep):
"""Identity processor that does nothing."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
+41 -11
View File
@@ -17,17 +17,26 @@ from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import (
ObservationProcessor,
ProcessorStepRegistry,
)
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessor(ObservationProcessor):
"""Rename processor that renames keys in the observation."""
@ProcessorStepRegistry.register(name="rename_observations_processor")
class RenameObservationsProcessorStep(ObservationProcessorStep):
"""
A processor step that renames keys in an observation dictionary.
This step is useful for creating a standardized data interface by mapping keys
from an environment's format to the format expected by a LeRobot policy or
other downstream components.
Attributes:
rename_map: A dictionary mapping from old key names to new key names.
Keys present in an observation that are not in this map will
be kept with their original names.
"""
rename_map: dict[str, str] = field(default_factory=dict)
@@ -44,16 +53,37 @@ class RenameProcessor(ObservationProcessor):
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Transforms:
- Each key in the observation that appears in `rename_map` is renamed to its value.
- Keys not in `rename_map` remain unchanged.
"""
return {self.rename_map.get(k, k): v for k, v in features.items()}
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = features.copy()
new_features[PipelineFeatureType.OBSERVATION] = {
self.rename_map.get(k, k): v for k, v in features[PipelineFeatureType.OBSERVATION].items()
}
return new_features
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
"""Rename keys in the stats dictionary according to rename_map (defensive copy)."""
"""
Renames the top-level keys in a statistics dictionary using a provided mapping.
This is a helper function typically used to keep normalization statistics
consistent with renamed observation or action features. It performs a defensive
deep copy to avoid modifying the original `stats` dictionary.
Args:
stats: A nested dictionary of statistics, where top-level keys are
feature names (e.g., `{"observation.state": {"mean": 0.5}}`).
rename_map: A dictionary mapping old feature names to new feature names.
Returns:
A new statistics dictionary with its top-level keys renamed. Returns an
empty dictionary if the input `stats` is empty.
"""
if not stats:
return {}
renamed: dict[str, dict[str, Any]] = {}
+114 -90
View File
@@ -1,5 +1,24 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tokenizer processor for handling text tokenization in robot transitions.
This script defines a processor for tokenizing natural language instructions from an environment transition.
It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
token IDs and attention masks, which are then added to the observation dictionary.
"""
from __future__ import annotations
@@ -9,16 +28,14 @@ from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.processor.pipeline import (
EnvTransition,
ObservationProcessor,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.utils.import_utils import _transformers_available
from .core import EnvTransition, TransitionKey
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
else:
@@ -27,68 +44,62 @@ else:
@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessor(ObservationProcessor):
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
class TokenizerProcessorStep(ObservationProcessorStep):
"""
Processor step to tokenize a natural language task description.
This processor handles tokenization of task strings found in the complementary_data
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
to the observation data for model processing while preserving the original task string.
This step extracts a task string from the `complementary_data` of an `EnvTransition`,
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
token IDs and attention mask to the `observation` dictionary.
The processor supports both single strings and lists of strings as task inputs.
Requires the `transformers` library to be installed.
Args:
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
tokenizer: A tokenizer object (e.g., from transformers library) that implements
the __call__ method. If provided, tokenizer_name is ignored. This parameter
is not serialized and must be provided via overrides when loading.
max_length: Maximum sequence length for tokenization. Defaults to 512.
task_key: Key in complementary_data containing the task text. Defaults to "task".
padding: Padding strategy for tokenization. Defaults to "max_length".
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
Examples:
Using tokenizer name (auto-loaded):
```python
processor = TokenizerProcessor(tokenizer_name="bert-base-uncased", max_length=128)
```
Using custom tokenizer object:
```python
from transformers import AutoTokenizer
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = TokenizerProcessor(tokenizer=custom_tokenizer, max_length=128)
```
Attributes:
tokenizer_name: The name of a pretrained tokenizer from the Hugging Face Hub (e.g., "bert-base-uncased").
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
max_length: The maximum length to pad or truncate sequences to.
task_key: The key in `complementary_data` where the task string is stored.
padding_side: The side to pad on ('left' or 'right').
padding: The padding strategy ('max_length', 'longest', etc.).
truncation: Whether to truncate sequences longer than `max_length`.
input_tokenizer: The internal tokenizer instance, loaded during initialization.
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
max_length: int = 512
task_key: str = "task"
padding_side: str = "right"
padding: str = "max_length"
truncation: bool = True
# Internal tokenizer instance (not serialized)
_tokenizer: Any = field(default=None, init=False, repr=False)
# Internal tokenizer instance (not part of the config)
input_tokenizer: Any = field(default=None, init=False, repr=False)
def __post_init__(self):
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
"""
Initializes the tokenizer after the dataclass is created.
It checks for the availability of the `transformers` library and loads the tokenizer
either from a provided object or by name from the Hugging Face Hub.
Raises:
ImportError: If the `transformers` library is not installed.
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
"""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessor."
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessorStep."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self._tokenizer = self.tokenizer
self.input_tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if AutoTokenizer is None:
raise ImportError("AutoTokenizer is not available")
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
@@ -96,13 +107,14 @@ class TokenizerProcessor(ObservationProcessor):
)
def get_task(self, transition: EnvTransition) -> list[str] | None:
"""Extract and normalize task from complementary data.
"""
Extracts the task description(s) from the transition's complementary data.
Args:
transition: Input transition containing complementary_data.
transition: The environment transition.
Returns:
List of task strings if task is present, None otherwise.
A list of task strings, or None if the task key is not found or the value is None.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
@@ -115,7 +127,7 @@ class TokenizerProcessor(ObservationProcessor):
if task is None:
return None
# Convert to list of strings
# Standardize to a list of strings for the tokenizer
if isinstance(task, str):
return [task]
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
@@ -123,80 +135,82 @@ class TokenizerProcessor(ObservationProcessor):
return None
def observation(self, observation):
"""Process the transition by tokenizing the task text.
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
"""
Tokenizes the task description and adds it to the observation dictionary.
This method retrieves the task, tokenizes it, moves the resulting tensors to the
same device as other data in the transition, and updates the observation.
Args:
transition: Input transition containing complementary_data with task text.
observation: The original observation dictionary.
Returns:
Modified transition with tokenized task added to observation.
Raises:
ValueError: If tokenizer initialization failed.
The updated observation dictionary including token IDs and an attention mask.
"""
task = self.get_task(self.transition)
if task is None:
return observation
# Tokenize the task (creates CPU tensors)
# Tokenize the task (this will create CPU tensors)
tokenized_prompt = self._tokenize_text(task)
# Detect device from existing tensors in the transition
# Detect the device from existing tensors in the transition to ensure consistency
target_device = self._detect_device(self.transition)
# Move tokenized tensors to match the device of other data
# Move new tokenized tensors to the detected device
if target_device is not None:
tokenized_prompt = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_prompt.items()
}
# Get or create observation dict
# Create a new observation dict to avoid modifying the original in place
new_observation = dict(observation)
# Add tokenized data to observation
# Add tokenized data to the observation
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
return new_observation
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
"""Detect device from existing tensors in the transition.
"""
Detects the torch.device from existing tensors in the transition.
This allows the tokenized tensors to match the device of other data,
which is especially important for multi-GPU training with Accelerate.
It checks tensors in the observation dictionary first, then the action tensor.
Args:
transition: The transition to search for existing tensors.
transition: The environment transition.
Returns:
The device of the first tensor found, or None if no tensors exist.
The detected `torch.device`, or None if no tensors are found.
"""
# Check observation tensors first (most likely to exist)
# Check observation tensors first (most likely place to find tensors)
observation = transition.get(TransitionKey.OBSERVATION)
if observation:
for value in observation.values():
if isinstance(value, torch.Tensor):
return value.device
# Check action tensor
# Fallback to checking the action tensor
action = transition.get(TransitionKey.ACTION)
if isinstance(action, torch.Tensor):
return action.device
return None # No tensors found, keep on CPU
return None # No tensors found, default will be CPU
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
"""Tokenize text using the configured tokenizer.
"""
A wrapper around the tokenizer call.
Args:
text: Text string or list of strings to tokenize.
text: A string or list of strings to tokenize.
Returns:
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors.
"""
return self._tokenizer(
return self.input_tokenizer(
text,
max_length=self.max_length,
truncation=self.truncation,
@@ -206,10 +220,14 @@ class TokenizerProcessor(ObservationProcessor):
)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.
"""
Returns the serializable configuration of the processor.
Note: Only tokenizer_name is saved, not the tokenizer object itself.
When loading, provide the tokenizer via overrides if needed.
Note: The tokenizer object itself is not serialized. If the processor was initialized
with a tokenizer name, that name will be included in the config.
Returns:
A dictionary with the processor's configuration parameters.
"""
config = {
"max_length": self.max_length,
@@ -219,30 +237,36 @@ class TokenizerProcessor(ObservationProcessor):
"truncation": self.truncation,
}
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
# TODO(steven): Consider saving the name of the _tokenizer if it was loaded
# Only save tokenizer_name if it was used to create the tokenizer
if self.tokenizer_name is not None and self.tokenizer is None:
config["tokenizer_name"] = self.tokenizer_name
return config
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the feature contract.
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Adds feature definitions for the language tokens and attention mask.
This updates the policy features dictionary to include the new data added to the
observation, ensuring downstream components are aware of their shape and type.
Args:
features: Input feature dictionary.
features: The dictionary of existing policy features.
Returns:
Updated feature dictionary with tokenized task features added.
The updated dictionary of policy features.
"""
# Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask
# Add a feature for the token IDs if it doesn't already exist
if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
if OBS_LANGUAGE_TOKENS not in features:
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if OBS_LANGUAGE_ATTENTION_MASK not in features:
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
# Add a feature for the attention mask if it doesn't already exist
if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
+62 -24
View File
@@ -62,6 +62,7 @@ import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from pprint import pformat
from typing import Any
from lerobot.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
@@ -76,14 +77,20 @@ from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import (
to_dataset_frame,
to_output_robot_action,
to_transition_robot_observation,
to_transition_teleop_action,
from lerobot.processor import (
EnvTransition,
IdentityProcessorStep,
PolicyProcessorPipeline,
RobotProcessorPipeline,
TransitionKey,
)
from lerobot.processor.converters import (
action_to_transition,
identity_transition,
observation_to_transition,
transition_to_action,
transition_to_dataset_frame,
)
from lerobot.processor.pipeline import IdentityProcessor, TransitionKey
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import ( # noqa: F401
Robot,
@@ -236,23 +243,36 @@ def record_loop(
dataset: LeRobotDataset | None = None,
teleop: Teleoperator | list[Teleoperator] | None = None,
policy: PreTrainedPolicy | None = None,
preprocessor: RobotProcessor | None = None,
postprocessor: RobotProcessor | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
control_time_s: int | None = None,
teleop_action_processor: RobotProcessor | None = None, # runs after teleop
robot_action_processor: RobotProcessor | None = None, # runs before robot
robot_observation_processor: RobotProcessor | None = None, # runs after robot
teleop_action_processor: RobotProcessorPipeline[EnvTransition] | None = None, # runs after teleop
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] | None = None, # runs before robot
robot_observation_processor: RobotProcessorPipeline[EnvTransition] | None = None, # runs after robot
single_task: str | None = None,
display_data: bool = False,
):
teleop_action_processor = teleop_action_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
teleop_action_processor
or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition
)
)
robot_action_processor = robot_action_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=lambda tr: tr, to_output=to_output_robot_action
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
robot_action_processor
or RobotProcessorPipeline(
steps=[IdentityProcessorStep()],
to_transition=identity_transition,
to_output=transition_to_action,
)
)
robot_observation_processor = robot_observation_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
robot_observation_processor
or RobotProcessorPipeline(
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=identity_transition,
)
)
if dataset is not None and dataset.fps != fps:
@@ -265,7 +285,14 @@ def record_loop(
(
t
for t in teleop
if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
if isinstance(
t,
(
so100_leader.SO100Leader,
so101_leader.SO101Leader,
koch_leader.KochLeader,
),
)
),
None,
)
@@ -308,7 +335,7 @@ def record_loop(
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
if dataset is not None:
observation_frame = to_dataset_frame(
observation_frame = transition_to_dataset_frame(
obs_transition, dataset.features
) # Convert the observation to the dataset format
@@ -334,6 +361,7 @@ def record_loop(
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
# TODO(Steven): This assumes that the processor passed by the user should have identity_transition as to_output.
teleop_transition = teleop_action_processor(act)
elif isinstance(teleop, list):
@@ -366,7 +394,7 @@ def record_loop(
# Write to dataset
if dataset is not None:
# If to_dataset_frame is provided, use it to merge the transitions.
# If transition_to_dataset_frame is provided, use it to merge the transitions.
merged = []
if obs_transition is not None: # The observation from the robot
merged.append(obs_transition)
@@ -374,13 +402,15 @@ def record_loop(
merged.append(teleop_transition)
if policy_transition is not None: # The action from policy
merged.append(policy_transition)
frame = to_dataset_frame(
frame = transition_to_dataset_frame(
merged if len(merged) > 1 else merged[0], dataset.features
) # Convert the observation to the dataset format
dataset.add_frame(frame, task=single_task)
if display_data:
log_rerun_data([obs_transition, teleop_transition or policy_transition])
log_rerun_data(
observation=obs_transition.get(TransitionKey.OBSERVATION), action=robot_action_to_send
)
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
@@ -400,7 +430,15 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
dataset_features = {**action_features, **obs_features}
# Add next.* features that are generated during recording
transition_features = {
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
"next.truncated": {"dtype": "bool", "shape": (1,), "names": None},
}
dataset_features = {**action_features, **obs_features, **transition_features}
if cfg.resume:
dataset = LeRobotDataset(
+8 -8
View File
@@ -47,9 +47,8 @@ from pprint import pformat
from lerobot.configs import parser
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
from lerobot.processor.pipeline import IdentityProcessor
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_action
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -57,6 +56,7 @@ from lerobot.robots import ( # noqa: F401
hope_jr,
koch_follower,
make_robot_from_config,
reachy2,
so100_follower,
so101_follower,
)
@@ -86,7 +86,7 @@ class ReplayConfig:
# Use vocal synthesis to read events.
play_sounds: bool = True
# Optional processor for actions before sending to robot
robot_action_processor: RobotProcessor | None = None
robot_action_processor: RobotProcessorPipeline | None = None
@parser.wrap()
@@ -95,10 +95,10 @@ def replay(cfg: ReplayConfig):
logging.info(pformat(asdict(cfg)))
# Initialize robot action processor with default if not provided
robot_action_processor = cfg.robot_action_processor or RobotProcessor(
steps=[IdentityProcessor()],
to_transition=to_transition_teleop_action,
to_output=to_output_robot_action, # type: ignore[arg-type]
robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline(
steps=[IdentityProcessorStep()],
to_transition=action_to_transition,
to_output=transition_to_action, # type: ignore[arg-type]
)
# Reset processor
@@ -29,10 +29,10 @@ class BiSO100FollowerConfig(RobotConfig):
# Optional
left_arm_disable_torque_on_disconnect: bool = True
left_arm_max_relative_target: int | None = None
left_arm_max_relative_target: float | dict[str, float] | None = None
left_arm_use_degrees: bool = False
right_arm_disable_torque_on_disconnect: bool = True
right_arm_max_relative_target: int | None = None
right_arm_max_relative_target: float | dict[str, float] | None = None
right_arm_use_degrees: bool = False
# cameras (shared between both arms)
+3 -3
View File
@@ -44,8 +44,8 @@ class HopeJrArmConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -110,6 +110,7 @@ class KochFollower(Robot):
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
@@ -120,7 +121,6 @@ class KochFollower(Robot):
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
+3 -3
View File
@@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
+25
View File
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_reachy2 import Reachy2RobotConfig
from .robot_reachy2 import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Robot,
)
@@ -0,0 +1,107 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.cameras.configs import ColorMode
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("reachy2")
@dataclass
class Reachy2RobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors.
max_relative_target: float | None = None
# IP address of the Reachy 2 robot
ip_address: str | None = "localhost"
# If True, turn_off_smoothly() will be sent to the robot before disconnecting.
disable_torque_on_disconnect: bool = False
# Tag for external commands control
# Set to True if you use an external commands system to control the robot,
# such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation
# If True, robot.send_action() will not send commands to the robot.
use_external_commands: bool = False
# Robot parts
# Set to False to not add the corresponding joints part to the robot list of joints.
# By default, all parts are set to True.
with_mobile_base: bool = True
with_l_arm: bool = True
with_r_arm: bool = True
with_neck: bool = True
with_antennas: bool = True
# Robot cameras
# Set to True if you want to use the corresponding cameras in the observations.
# By default, only the teleop cameras are used.
with_left_teleop_camera: bool = True
with_right_teleop_camera: bool = True
with_torso_camera: bool = False
cameras: dict[str, CameraConfig] = field(default_factory=dict)
def __post_init__(self) -> None:
# Add cameras with same ip_address as the robot
if self.with_left_teleop_camera:
self.cameras["teleop_left"] = Reachy2CameraConfig(
name="teleop",
image_type="left",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
if self.with_right_teleop_camera:
self.cameras["teleop_right"] = Reachy2CameraConfig(
name="teleop",
image_type="right",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
if self.with_torso_camera:
self.cameras["torso_rgb"] = Reachy2CameraConfig(
name="depth",
image_type="rgb",
ip_address=self.ip_address,
fps=15,
width=640,
height=480,
color_mode=ColorMode.RGB,
)
super().__post_init__()
if not (
self.with_mobile_base
or self.with_l_arm
or self.with_r_arm
or self.with_neck
or self.with_antennas
):
raise ValueError(
"No Reachy2Robot part used.\n"
"At least one part of the robot must be set to True "
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
)
+230
View File
@@ -0,0 +1,230 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any
import numpy as np
from reachy2_sdk import ReachySDK
from lerobot.cameras.utils import make_cameras_from_configs
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .configuration_reachy2 import Reachy2RobotConfig
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_NECK_JOINTS = {
"neck_yaw.pos": "head.neck.yaw",
"neck_pitch.pos": "head.neck.pitch",
"neck_roll.pos": "head.neck.roll",
}
REACHY2_ANTENNAS_JOINTS = {
"l_antenna.pos": "head.l_antenna",
"r_antenna.pos": "head.r_antenna",
}
REACHY2_R_ARM_JOINTS = {
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
"r_wrist_roll.pos": "r_arm.wrist.roll",
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
"r_gripper.pos": "r_arm.gripper",
}
REACHY2_L_ARM_JOINTS = {
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
"l_wrist_roll.pos": "l_arm.wrist.roll",
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
"l_gripper.pos": "l_arm.gripper",
}
REACHY2_VEL = {
"mobile_base.vx": "vx",
"mobile_base.vy": "vy",
"mobile_base.vtheta": "vtheta",
}
class Reachy2Robot(Robot):
"""
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
"""
config_class = Reachy2RobotConfig
name = "reachy2"
def __init__(self, config: Reachy2RobotConfig):
super().__init__(config)
self.config = config
self.robot_type = self.config.type
self.use_external_commands = self.config.use_external_commands
self.reachy: None | ReachySDK = None
self.cameras = make_cameras_from_configs(config.cameras)
self.logs: dict[str, float] = {}
self.joints_dict: dict[str, str] = self._generate_joints_dict()
@property
def observation_features(self) -> dict[str, Any]:
return {**self.motors_features, **self.camera_features}
@property
def action_features(self) -> dict[str, type]:
return self.motors_features
@property
def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]:
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
@property
def motors_features(self) -> dict[str, type]:
if self.config.with_mobile_base:
return {
**dict.fromkeys(
self.joints_dict.keys(),
float,
),
**dict.fromkeys(
REACHY2_VEL.keys(),
float,
),
}
else:
return dict.fromkeys(self.joints_dict.keys(), float)
@property
def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False
def connect(self, calibrate: bool = False) -> None:
self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected:
raise ConnectionError()
for cam in self.cameras.values():
cam.connect()
self.configure()
def configure(self) -> None:
if self.reachy is not None:
self.reachy.turn_on()
self.reachy.reset_default_limits()
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
def _generate_joints_dict(self) -> dict[str, str]:
joints = {}
if self.config.with_neck:
joints.update(REACHY2_NECK_JOINTS)
if self.config.with_l_arm:
joints.update(REACHY2_L_ARM_JOINTS)
if self.config.with_r_arm:
joints.update(REACHY2_R_ARM_JOINTS)
if self.config.with_antennas:
joints.update(REACHY2_ANTENNAS_JOINTS)
return joints
def _get_state(self) -> dict[str, float]:
if self.reachy is not None:
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base:
return pos_dict
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
return {**pos_dict, **vel_dict}
else:
return {}
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict: dict[str, Any] = {}
# Read Reachy 2 state
before_read_t = time.perf_counter()
obs_dict.update(self._get_state())
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
# Capture images from cameras
for cam_key, cam in self.cameras.items():
obs_dict[cam_key] = cam.async_read()
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
if self.reachy is not None:
if not self.is_connected:
raise ConnectionError()
before_write_t = time.perf_counter()
vel = {}
goal_pos = {}
for key, val in action.items():
if key not in self.joints_dict:
if key not in REACHY2_VEL:
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
else:
vel[REACHY2_VEL[key]] = float(val)
else:
if not self.use_external_commands and self.config.max_relative_target is not None:
goal_pos[key] = float(val)
goal_present_pos = {
key: (
goal_pos[key],
self.reachy.joints[self.joints_dict[key]].present_position,
)
}
safe_goal_pos = ensure_safe_goal_position(
goal_present_pos, float(self.config.max_relative_target)
)
val = safe_goal_pos[key]
self.reachy.joints[self.joints_dict[key]].goal_position = float(val)
if self.config.with_mobile_base:
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
# We don't send the goal positions if we control Reachy 2 externally
if not self.use_external_commands:
self.reachy.send_goal_positions()
if self.config.with_mobile_base:
self.reachy.mobile_base.send_speed_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
return action
def disconnect(self) -> None:
if self.reachy is not None:
for cam in self.cameras.values():
cam.disconnect()
if self.config.disable_torque_on_disconnect:
self.reachy.turn_off_smoothly()
self.reachy.disconnect()
@@ -30,9 +30,9 @@ class SO100FollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -1,4 +1,4 @@
# !/usr/bin/env python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
@@ -17,39 +17,48 @@
from dataclasses import dataclass, field
import numpy as np
from scipy.spatial.transform import Rotation
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_STATE
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.pipeline import (
ActionProcessor,
ComplementaryDataProcessor,
from lerobot.processor import (
ActionProcessorStep,
ComplementaryDataProcessorStep,
EnvTransition,
ObservationProcessor,
ObservationProcessorStep,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.robots.robot import Robot
from lerobot.utils.rotation import Rotation
@ProcessorStepRegistry.register("ee_reference_and_delta")
@dataclass
class EEReferenceAndDelta(ActionProcessor):
class EEReferenceAndDelta(ActionProcessorStep):
"""
Compute the desired end-effector pose from the target pose and the current pose.
Computes a target end-effector pose from a relative delta command.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
This step takes a desired change in position and orientation (`target_*`) and applies it to a
reference end-effector pose to calculate an absolute target pose. The reference pose is derived
from the current robot joint positions using forward kinematics.
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
The processor can operate in two modes:
1. `use_latched_reference=True`: The reference pose is "latched" or saved at the moment the action
is first enabled. Subsequent commands are relative to this fixed reference.
2. `use_latched_reference=False`: The reference pose is updated to the robot's current pose at
every step.
Attributes:
kinematics: The robot's kinematic model for forward kinematics.
end_effector_step_sizes: A dictionary scaling the input delta commands.
motor_names: A list of motor names required for forward kinematics.
use_latched_reference: If True, latch the reference pose on enable; otherwise, always use the
current pose as the reference.
reference_ee_pose: Internal state storing the latched reference pose.
_prev_enabled: Internal state to detect the rising edge of the enable signal.
_command_when_disabled: Internal state to hold the last command while disabled.
"""
kinematics: RobotKinematics
@@ -82,13 +91,13 @@ class EEReferenceAndDelta(ActionProcessor):
# Current pose from FK on measured joints
t_curr = self.kinematics.forward_kinematics(q)
enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))
enabled = bool(new_action.pop("enabled", 0))
tx = float(new_action.pop("target_x", 0.0))
ty = float(new_action.pop("target_y", 0.0))
tz = float(new_action.pop("target_z", 0.0))
wx = float(new_action.pop("target_wx", 0.0))
wy = float(new_action.pop("target_wy", 0.0))
wz = float(new_action.pop("target_wz", 0.0))
desired = None
@@ -124,54 +133,57 @@ class EEReferenceAndDelta(ActionProcessor):
# Write action fields
pos = desired[:3, 3]
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
new_action[f"{ACTION}.ee.x"] = float(pos[0])
new_action[f"{ACTION}.ee.y"] = float(pos[1])
new_action[f"{ACTION}.ee.z"] = float(pos[2])
new_action[f"{ACTION}.ee.wx"] = float(tw[0])
new_action[f"{ACTION}.ee.wy"] = float(tw[1])
new_action[f"{ACTION}.ee.wz"] = float(tw[2])
new_action["ee.x"] = float(pos[0])
new_action["ee.y"] = float(pos[1])
new_action["ee.z"] = float(pos[2])
new_action["ee.wx"] = float(tw[0])
new_action["ee.wy"] = float(tw[1])
new_action["ee.wz"] = float(tw[2])
self._prev_enabled = enabled
return new_action
def reset(self):
"""Resets the internal state of the processor."""
self._prev_enabled = False
self.reference_ee_pose = None
self._command_when_disabled = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.enabled", None)
features.pop(f"{ACTION}.target_x", None)
features.pop(f"{ACTION}.target_y", None)
features.pop(f"{ACTION}.target_z", None)
features.pop(f"{ACTION}.target_wx", None)
features.pop(f"{ACTION}.target_wy", None)
features.pop(f"{ACTION}.target_wz", None)
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("enabled", None)
features[PipelineFeatureType.ACTION].pop("target_x", None)
features[PipelineFeatureType.ACTION].pop("target_y", None)
features[PipelineFeatureType.ACTION].pop("target_z", None)
features[PipelineFeatureType.ACTION].pop("target_wx", None)
features[PipelineFeatureType.ACTION].pop("target_wy", None)
features[PipelineFeatureType.ACTION].pop("target_wz", None)
features[f"{ACTION}.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[PipelineFeatureType.ACTION]["ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
@ProcessorStepRegistry.register("ee_bounds_and_safety")
@dataclass
class EEBoundsAndSafety(ActionProcessor):
class EEBoundsAndSafety(ActionProcessorStep):
"""
Clip the end-effector pose to the bounds and check for jumps.
Clips the end-effector pose to predefined bounds and checks for unsafe jumps.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
This step ensures that the target end-effector pose remains within a safe operational workspace.
It also moderates the command to prevent large, sudden movements between consecutive steps.
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
Attributes:
end_effector_bounds: A dictionary with "min" and "max" keys for position clipping.
max_ee_step_m: The maximum allowed change in position (in meters) between steps.
max_ee_twist_step_rad: The maximum allowed change in orientation (in radians) between steps.
_last_pos: Internal state storing the last commanded position.
_last_twist: Internal state storing the last commanded orientation.
"""
end_effector_bounds: dict
@@ -181,15 +193,17 @@ class EEBoundsAndSafety(ActionProcessor):
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, act: dict) -> dict:
x = act.get(f"{ACTION}.ee.x", None)
y = act.get(f"{ACTION}.ee.y", None)
z = act.get(f"{ACTION}.ee.z", None)
wx = act.get(f"{ACTION}.ee.wx", None)
wy = act.get(f"{ACTION}.ee.wy", None)
wz = act.get(f"{ACTION}.ee.wz", None)
x = act.get("ee.x", None)
y = act.get("ee.y", None)
z = act.get("ee.z", None)
wx = act.get("ee.wx", None)
wy = act.get("ee.wy", None)
wz = act.get("ee.wz", None)
if None in (x, y, z, wx, wy, wz):
return act
raise ValueError(
"Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action"
)
pos = np.array([x, y, z], dtype=float)
twist = np.array([wx, wy, wz], dtype=float)
@@ -208,38 +222,40 @@ class EEBoundsAndSafety(ActionProcessor):
self._last_pos = pos
self._last_twist = twist
act[f"{ACTION}.ee.x"] = float(pos[0])
act[f"{ACTION}.ee.y"] = float(pos[1])
act[f"{ACTION}.ee.z"] = float(pos[2])
act[f"{ACTION}.ee.wx"] = float(twist[0])
act[f"{ACTION}.ee.wy"] = float(twist[1])
act[f"{ACTION}.ee.wz"] = float(twist[2])
act["ee.x"] = float(pos[0])
act["ee.y"] = float(pos[1])
act["ee.z"] = float(pos[2])
act["ee.wx"] = float(twist[0])
act["ee.wy"] = float(twist[1])
act["ee.wz"] = float(twist[2])
return act
def reset(self):
"""Resets the last known position and orientation."""
self._last_pos = None
self._last_twist = None
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
@dataclass
class InverseKinematicsEEToJoints(ProcessorStep):
"""
Compute the desired joint positions from the desired end-effector pose.
Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
This step translates a Cartesian command (position and orientation of the end-effector) into
the corresponding joint-space commands for each motor.
Output ACTION keys:
{
"action.joint_name_1.pos": float,
"action.joint_name_2.pos": float,
...
"action.joint_name_n.pos": float,
}
Attributes:
kinematics: The robot's kinematic model for inverse kinematics.
motor_names: A list of motor names for which to compute joint positions.
q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver.
initial_guess_current_joints: If True, use the robot's current joint state as the IK guess.
If False, use the solution from the previous step.
"""
kinematics: RobotKinematics
@@ -248,18 +264,19 @@ class InverseKinematicsEEToJoints(ProcessorStep):
initial_guess_current_joints: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
new_transition = transition.copy()
act = new_transition.get(TransitionKey.ACTION) or {}
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
x = act.get(f"{ACTION}.ee.x", None)
y = act.get(f"{ACTION}.ee.y", None)
z = act.get(f"{ACTION}.ee.z", None)
wx = act.get(f"{ACTION}.ee.wx", None)
wy = act.get(f"{ACTION}.ee.wy", None)
wz = act.get(f"{ACTION}.ee.wz", None)
x = act.get("ee.x", None)
y = act.get("ee.y", None)
z = act.get("ee.z", None)
wx = act.get("ee.wx", None)
wy = act.get("ee.wy", None)
wz = act.get("ee.wz", None)
if None in (x, y, z, wx, wy, wz):
return transition
return new_transition
# Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None)
@@ -286,23 +303,31 @@ class InverseKinematicsEEToJoints(ProcessorStep):
new_act = dict(act)
for i, name in enumerate(self.motor_names):
if name == "gripper":
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
# TODO(pepijn): Investigate if this is correct
# Do we want an observation key in the action field?
new_act["gripper.pos"] = float(raw["gripper"])
else:
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
transition[TransitionKey.ACTION] = new_act
new_act[f"{name}.pos"] = float(q_target[i])
new_transition[TransitionKey.ACTION] = new_act
if not self.initial_guess_current_joints:
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return transition
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION]["gripper.pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
for name in self.motor_names:
features[f"{ACTION}.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
return features
def reset(self):
"""Resets the initial guess for the IK solver."""
self.q_curr = None
@@ -310,17 +335,18 @@ class InverseKinematicsEEToJoints(ProcessorStep):
@dataclass
class GripperVelocityToJoint(ProcessorStep):
"""
Convert the gripper velocity to a joint velocity.
Converts a gripper velocity command into a target gripper joint position.
Input ACTION keys:
{
"action.gripper": float,
}
This step integrates a normalized velocity command over time to produce a position command,
taking the current gripper position as a starting point. It also supports a discrete mode
where integer actions map to open, close, or no-op.
Output ACTION keys:
{
"action.gripper.pos": float,
}
Attributes:
motor_names: A list of motor names, which must include 'gripper'.
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
clip_min: The minimum allowed gripper joint position.
clip_max: The maximum allowed gripper joint position.
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
"""
motor_names: list[str]
@@ -330,67 +356,72 @@ class GripperVelocityToJoint(ProcessorStep):
discrete_gripper: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION) or {}
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
act = new_transition.get(TransitionKey.ACTION) or {}
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if f"{ACTION}.gripper" not in act:
return transition
if "gripper" not in act:
raise ValueError("Required action key 'gripper' not found in transition")
if "gripper" not in self.motor_names:
new_act = dict(act)
new_act.pop(f"{ACTION}.gripper", None)
transition[TransitionKey.ACTION] = new_act
return transition
raise ValueError(
f"Required motor name 'gripper' not found in self.motor_names={self.motor_names}"
)
if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2]
# 0: open, 1: close, 2: stay
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
gripper_action = act.get(f"{ACTION}.gripper", 1.0)
gripper_action = act.get("gripper", 1.0)
gripper_action = gripper_action - 1.0
gripper_action *= self.clip_max
act[f"{ACTION}.gripper"] = gripper_action
act["gripper"] = gripper_action
# Get current gripper position from complementary data
raw = comp.get("raw_joint_positions") or {}
curr_pos = float(raw.get("gripper"))
# Compute desired gripper velocity
u = float(act.get(f"{ACTION}.gripper", 0.0))
u = float(act.get("gripper", 0.0))
delta = u * float(self.speed_factor)
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
new_act = dict(act)
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
new_act.pop(f"{ACTION}.gripper", None)
transition[TransitionKey.ACTION] = new_act
new_act["gripper.pos"] = gripper_pos
new_act.pop("gripper", None)
new_transition[TransitionKey.ACTION] = new_act
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
transition[TransitionKey.OBSERVATION] = obs
return transition
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("gripper", None)
features[PipelineFeatureType.ACTION]["gripper.pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.gripper.pos"] = PolicyFeature(
type=FeatureType.STATE, shape=(1,)
)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.gripper", None)
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
return features
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
@dataclass
class ForwardKinematicsJointsToEE(ObservationProcessor):
class ForwardKinematicsJointsToEE(ObservationProcessorStep):
"""
Compute the end-effector pose from the joint positions.
Computes the end-effector pose from joint positions using forward kinematics (FK).
Input OBSERVATION keys:
{
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
}
This step is typically used to add the robot's Cartesian pose to the observation space,
which can be useful for visualization or as an input to a policy.
Output OBSERVATION keys:
{
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
}
Attributes:
kinematics: The robot's kinematic model.
motor_names: A list of motor names whose joint positions are used for FK.
"""
kinematics: RobotKinematics
@@ -398,7 +429,7 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
def observation(self, obs: dict) -> dict:
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
return obs
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
t = self.kinematics.forward_kinematics(q)
@@ -413,21 +444,29 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
return obs
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We specify the dataset features of this step that we want to be stored in the dataset
for k in ["x", "y", "z", "wx", "wy", "wz"]:
features[f"{OBS_STATE}.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.{k}"] = PolicyFeature(
type=FeatureType.STATE, shape=(1,)
)
return features
@ProcessorStepRegistry.register("add_robot_observation")
@dataclass
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessorStep):
"""
Read the robot's current observation and insert it into the transition as complementary data.
Reads the robot's current observation and adds it to the transition's complementary data.
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
{ "<motor_name>": <float position>, ... }
This step acts as a bridge to the physical robot, injecting its real-time sensor readings
(like raw joint positions) into the data processing pipeline. This data is then available
for other processing steps.
Attributes:
robot: An instance of a `Robot` class used to get observations from hardware.
"""
robot: Robot
@@ -444,3 +483,8 @@ class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
if isinstance(k, str) and k.endswith(".pos")
}
return new_comp
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@@ -30,9 +30,9 @@ class SO101FollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -24,11 +24,6 @@ from ..config import RobotConfig
@RobotConfig.register_subclass("stretch3")
@dataclass
class Stretch3RobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
+5 -1
View File
@@ -57,6 +57,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .bi_so100_follower import BiSO100Follower
return BiSO100Follower(config)
elif config.type == "reachy2":
from .reachy2 import Reachy2Robot
return Reachy2Robot(config)
elif config.type == "mock_robot":
from tests.mocks.mock_robot import MockRobot
@@ -67,7 +71,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset
def ensure_safe_goal_position(
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float]
) -> dict[str, float]:
"""Caps relative action target magnitude for safety."""
+3 -3
View File
@@ -28,15 +28,15 @@ class ViperXConfig(RobotConfig):
# /!\ FOR SAFETY, READ THIS /!\
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
# names to the max_relative_target value for that motor.
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
# When you feel more confident with teleoperation or running the policy, you can extend
# this safety limit and even removing it by setting it to `null`.
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
max_relative_target: int | None = 5
max_relative_target: float | dict[str, float] = 5.0
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
+45 -14
View File
@@ -56,6 +56,8 @@ from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import Any
from typing import Any
import einops
import gymnasium as gym
@@ -69,9 +71,11 @@ from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.policies.factory import make_policy
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.processor.core import TransitionKey
from lerobot.processor.pipeline import PolicyProcessorPipeline
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -84,6 +88,10 @@ from lerobot.utils.utils import (
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
@@ -120,7 +128,6 @@ def rollout(
The dictionary described above.
"""
assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
device = get_device_from_parameters(policy)
# Reset the policy and environments.
policy.reset()
@@ -151,19 +158,18 @@ def rollout(
if return_observations:
all_observations.append(deepcopy(observation))
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
# Convert to CPU / numpy.
action = action.to("cpu").numpy()
action: np.ndarray = action.to("cpu").numpy()
action: np.ndarray = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
@@ -220,6 +226,10 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,
@@ -296,8 +306,14 @@ def eval_policy(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
)
rollout_data = rollout(
env,
policy,
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
@@ -479,13 +495,28 @@ def eval_main(cfg: EvalPipelineConfig):
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy.eval()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env,
policy,
cfg.eval.n_episodes,
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,
+6 -16
View File
@@ -62,7 +62,7 @@ from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.processor.pipeline import TransitionKey
from lerobot.processor import TransitionKey
from lerobot.robots import so100_follower # noqa: F401
from lerobot.scripts.rl.gym_manipulator import (
create_transition,
@@ -98,9 +98,7 @@ from lerobot.utils.utils import (
ACTOR_SHUTDOWN_TIMEOUT = 30
#################################################
# Main entry point #
#################################################
# Main entry point
@parser.wrap()
@@ -207,9 +205,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
logging.info("[ACTOR] queues closed")
#################################################
# Core algorithm functions #
#################################################
# Core algorithm functions
def act_with_policy(
@@ -406,9 +402,7 @@ def act_with_policy(
busy_wait(1 / cfg.env.fps - dt_time)
#################################################
# Communication Functions - Group all gRPC/messaging functions #
#################################################
# Communication Functions - Group all gRPC/messaging functions
def establish_learner_connection(
@@ -653,9 +647,7 @@ def interactions_stream(
return services_pb2.Empty()
#################################################
# Policy functions #
#################################################
# Policy functions
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
@@ -687,9 +679,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
#################################################
# Utilities functions #
#################################################
# Utilities functions
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
+53 -60
View File
@@ -29,25 +29,27 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs.configs import HILSerlRobotEnvConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
AddTeleopActionAsComplimentaryData,
AddTeleopEventsAsInfo,
DeviceProcessor,
GripperPenaltyProcessor,
ImageCropResizeProcessor,
InterventionActionProcessor,
JointVelocityProcessor,
MapDeltaActionToRobotAction,
MapTensorToDeltaActionDict,
MotorCurrentProcessor,
Numpy2TorchActionProcessor,
RewardClassifierProcessor,
RobotProcessor,
TimeLimitProcessor,
ToBatchProcessor,
Torch2NumpyActionProcessor,
VanillaObservationProcessor,
AddBatchDimensionProcessorStep,
AddTeleopActionAsComplimentaryDataStep,
AddTeleopEventsAsInfoStep,
DataProcessorPipeline,
DeviceProcessorStep,
EnvTransition,
GripperPenaltyProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
JointVelocityProcessorStep,
MapDeltaActionToRobotActionStep,
MapTensorToDeltaActionDictStep,
MotorCurrentProcessorStep,
Numpy2TorchActionProcessorStep,
RewardClassifierProcessorStep,
TimeLimitProcessorStep,
Torch2NumpyActionProcessorStep,
TransitionKey,
VanillaObservationProcessorStep,
create_transition,
)
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.robots import ( # noqa: F401
RobotConfig,
make_robot_from_config,
@@ -98,21 +100,6 @@ class GymManipulatorConfig:
device: str = "cpu"
def create_transition(
observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
) -> dict[str, Any]:
"""Create an EnvTransition dictionary with default values."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info if info is not None else {},
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
}
def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None:
"""Reset robot arm to target position using smooth trajectory."""
current_position_dict = robot_arm.bus.sync_read("Present_Position")
@@ -375,19 +362,21 @@ def make_processors(
if cfg.name == "gym_hil":
action_pipeline_steps = [
InterventionActionProcessor(terminate_on_success=terminate_on_success),
Torch2NumpyActionProcessor(),
InterventionActionProcessorStep(terminate_on_success=terminate_on_success),
Torch2NumpyActionProcessorStep(),
]
# Minimal processor pipeline for GymHIL simulation
env_pipeline_steps = [
Numpy2TorchActionProcessor(),
VanillaObservationProcessor(),
ToBatchProcessor(),
DeviceProcessor(device=device),
Numpy2TorchActionProcessorStep(),
VanillaObservationProcessorStep(),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=device),
]
return RobotProcessor(steps=env_pipeline_steps), RobotProcessor(steps=action_pipeline_steps)
return DataProcessorPipeline(steps=env_pipeline_steps), DataProcessorPipeline(
steps=action_pipeline_steps
)
# Full processor pipeline for real robot environment
# Get robot and motor information for kinematics
@@ -402,13 +391,13 @@ def make_processors(
joint_names=motor_names,
)
env_pipeline_steps = [VanillaObservationProcessor()]
env_pipeline_steps = [VanillaObservationProcessorStep()]
if cfg.processor.observation is not None:
if cfg.processor.observation.add_joint_velocity_to_observation:
env_pipeline_steps.append(JointVelocityProcessor(dt=1.0 / cfg.fps))
env_pipeline_steps.append(JointVelocityProcessorStep(dt=1.0 / cfg.fps))
if cfg.processor.observation.add_current_to_observation:
env_pipeline_steps.append(MotorCurrentProcessor(robot=env.robot))
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
if kinematics_solver is not None:
env_pipeline_steps.append(
@@ -420,7 +409,7 @@ def make_processors(
if cfg.processor.image_preprocessing is not None:
env_pipeline_steps.append(
ImageCropResizeProcessor(
ImageCropResizeProcessorStep(
crop_params_dict=cfg.processor.image_preprocessing.crop_params_dict,
resize_size=cfg.processor.image_preprocessing.resize_size,
)
@@ -429,13 +418,13 @@ def make_processors(
# Add time limit processor if reset config exists
if cfg.processor.reset is not None:
env_pipeline_steps.append(
TimeLimitProcessor(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
)
# Add gripper penalty processor if gripper config exists and enabled
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
env_pipeline_steps.append(
GripperPenaltyProcessor(
GripperPenaltyProcessorStep(
penalty=cfg.processor.gripper.gripper_penalty,
max_gripper_pos=cfg.processor.max_gripper_pos,
)
@@ -446,7 +435,7 @@ def make_processors(
and cfg.processor.reward_classifier.pretrained_path is not None
):
env_pipeline_steps.append(
RewardClassifierProcessor(
RewardClassifierProcessorStep(
pretrained_path=cfg.processor.reward_classifier.pretrained_path,
device=device,
success_threshold=cfg.processor.reward_classifier.success_threshold,
@@ -455,14 +444,14 @@ def make_processors(
)
)
env_pipeline_steps.append(ToBatchProcessor())
env_pipeline_steps.append(DeviceProcessor(device=device))
env_pipeline_steps.append(AddBatchDimensionProcessorStep())
env_pipeline_steps.append(DeviceProcessorStep(device=device))
action_pipeline_steps = [
AddTeleopActionAsComplimentaryData(teleop_device=teleop_device),
AddTeleopEventsAsInfo(teleop_device=teleop_device),
AddTeleopActionAsComplimentaryDataStep(teleop_device=teleop_device),
AddTeleopEventsAsInfoStep(teleop_device=teleop_device),
AddRobotObservationAsComplimentaryData(robot=env.robot),
InterventionActionProcessor(
InterventionActionProcessorStep(
use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False,
terminate_on_success=terminate_on_success,
),
@@ -472,8 +461,10 @@ def make_processors(
if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None:
# Add EE bounds and safety processor
inverse_kinematics_steps = [
MapTensorToDeltaActionDict(),
MapDeltaActionToRobotAction(),
MapTensorToDeltaActionDictStep(
use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False
),
MapDeltaActionToRobotActionStep(),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes=cfg.processor.inverse_kinematics.end_effector_step_sizes,
@@ -497,15 +488,15 @@ def make_processors(
]
action_pipeline_steps.extend(inverse_kinematics_steps)
return RobotProcessor(steps=env_pipeline_steps), RobotProcessor(steps=action_pipeline_steps)
return DataProcessorPipeline(steps=env_pipeline_steps), DataProcessorPipeline(steps=action_pipeline_steps)
def step_env_and_process_transition(
env: gym.Env,
transition: EnvTransition,
action: torch.Tensor,
env_processor: RobotProcessor,
action_processor: RobotProcessor,
env_processor: DataProcessorPipeline,
action_processor: DataProcessorPipeline,
):
"""
Execute one step with processor pipeline.
@@ -554,8 +545,8 @@ def step_env_and_process_transition(
def control_loop(
env: gym.Env,
env_processor: RobotProcessor,
action_processor: RobotProcessor,
env_processor: DataProcessorPipeline,
action_processor: DataProcessorPipeline,
teleop_device: Teleoperator,
cfg: GymManipulatorConfig,
) -> None:
@@ -709,7 +700,9 @@ def control_loop(
dataset.push_to_hub()
def replay_trajectory(env: gym.Env, action_processor: RobotProcessor, cfg: GymManipulatorConfig) -> None:
def replay_trajectory(
env: gym.Env, action_processor: DataProcessorPipeline, cfg: GymManipulatorConfig
) -> None:
"""Replay recorded trajectory on robot environment."""
assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay"
+3 -14
View File
@@ -103,11 +103,6 @@ from lerobot.utils.wandb_utils import WandBLogger
LOG_PREFIX = "[LEARNER]"
#################################################
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
#################################################
@parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig):
if not use_threads(cfg):
@@ -250,9 +245,7 @@ def start_learner_threads(
logging.info("[LEARNER] queues closed")
#################################################
# Core algorithm functions #
#################################################
# Core algorithm functions
def add_actor_information_and_train(
@@ -820,9 +813,7 @@ def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.M
return optimizers, lr_scheduler
#################################################
# Training setup functions #
#################################################
# Training setup functions
def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig:
@@ -1023,9 +1014,7 @@ def initialize_offline_replay_buffer(
return offline_replay_buffer
#################################################
# Utilities/Helpers functions #
#################################################
# Utilities/Helpers functions
def get_observation_features(
+46 -6
View File
@@ -26,7 +26,7 @@ from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
@@ -65,6 +65,28 @@ def update_policy(
use_amp: bool = False,
lock=None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
learning rate scheduler. It also handles mixed-precision training via a GradScaler.
Args:
train_metrics: A MetricsTracker instance to record training statistics.
policy: The policy model to be trained.
batch: A batch of training data.
optimizer: The optimizer used to update the policy's parameters.
grad_clip_norm: The maximum norm for gradient clipping.
grad_scaler: The GradScaler for automatic mixed-precision training.
lr_scheduler: An optional learning rate scheduler.
use_amp: A boolean indicating whether to use automatic mixed precision.
lock: An optional lock for thread-safe optimizer updates.
Returns:
A tuple containing:
- The updated MetricsTracker with new statistics for this step.
- A dictionary of outputs from the policy's forward pass, for logging purposes.
"""
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
@@ -108,6 +130,20 @@ def update_policy(
@parser.wrap()
def train(cfg: TrainPipelineConfig):
"""
Main function to train a policy.
This function orchestrates the entire training pipeline, including:
- Setting up logging, seeding, and device configuration.
- Creating the dataset, evaluation environment (if applicable), policy, and optimizer.
- Handling resumption from a checkpoint.
- Running the main training loop, which involves fetching data batches and calling `update_policy`.
- Periodically logging metrics, saving model checkpoints, and evaluating the policy.
- Pushing the final trained model to the Hugging Face Hub if configured.
Args:
cfg: A `TrainPipelineConfig` object containing all training configurations.
"""
cfg.validate()
logging.info(pformat(cfg.to_dict()))
@@ -153,9 +189,11 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json")
preprocessor.from_pretrained(
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
)
postprocessor.from_pretrained(
cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json"
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
@@ -260,9 +298,11 @@ def train(cfg: TrainPipelineConfig):
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
env=eval_env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
+51 -20
View File
@@ -55,19 +55,20 @@ import logging
import time
from dataclasses import asdict, dataclass
from pprint import pformat
from typing import Any
import rerun as rr
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.processor import RobotProcessor
from lerobot.processor import EnvTransition, IdentityProcessorStep, RobotProcessorPipeline, TransitionKey
from lerobot.processor.converters import (
to_output_robot_action,
to_transition_robot_observation,
to_transition_teleop_action,
action_to_transition,
identity_transition,
observation_to_transition,
transition_to_action,
)
from lerobot.processor.pipeline import IdentityProcessor
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -105,9 +106,9 @@ class TeleoperateConfig:
# Display all cameras on screen
display_data: bool = False
# Optional processors for data transformation
teleop_action_processor: RobotProcessor | None = None # runs after teleop
robot_action_processor: RobotProcessor | None = None # runs before robot
robot_observation_processor: RobotProcessor | None = None # runs after robot
teleop_action_processor: RobotProcessorPipeline | None = None # runs after teleop
robot_action_processor: RobotProcessorPipeline | None = None # runs before robot
robot_observation_processor: RobotProcessorPipeline | None = None # runs after robot
def teleop_loop(
@@ -116,21 +117,47 @@ def teleop_loop(
fps: int,
display_data: bool = False,
duration: float | None = None,
teleop_action_processor: RobotProcessor | None = None,
robot_action_processor: RobotProcessor | None = None,
robot_observation_processor: RobotProcessor | None = None,
teleop_action_processor: RobotProcessorPipeline[EnvTransition] | None = None,
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] | None = None,
robot_observation_processor: RobotProcessorPipeline[EnvTransition] | None = None,
):
"""
This function continuously reads actions from a teleoperation device, processes them through optional
pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a
specified frequency until a set duration is reached or it is manually interrupted.
Args:
teleop: The teleoperator device instance providing control actions.
robot: The robot instance being controlled.
fps: The target frequency for the control loop in frames per second.
display_data: If True, fetches robot observations and displays them in the console and Rerun.
duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely.
teleop_action_processor: An optional pipeline to process raw actions from the teleoperator.
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
robot_observation_processor: An optional pipeline to process raw observations from the robot.
"""
# Initialize processors with defaults if not provided
teleop_action_processor = teleop_action_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
teleop_action_processor
or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition
)
)
robot_action_processor = robot_action_processor or RobotProcessor(
steps=[IdentityProcessor()],
to_transition=lambda tr: tr,
to_output=to_output_robot_action, # type: ignore[arg-type]
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
robot_action_processor
or RobotProcessorPipeline(
steps=[IdentityProcessorStep()],
to_transition=identity_transition,
to_output=transition_to_action, # type: ignore[arg-type]
)
)
robot_observation_processor = robot_observation_processor or RobotProcessor(
steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
robot_observation_processor
or RobotProcessorPipeline(
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=identity_transition,
)
)
# Reset processors
@@ -161,7 +188,11 @@ def teleop_loop(
obs = robot.get_observation()
# Process robot observation through pipeline
obs_transition = robot_observation_processor(obs)
log_rerun_data([obs_transition, teleop_transition])
log_rerun_data(
observation=obs_transition.get(TransitionKey.OBSERVATION),
action=teleop_transition.get(TransitionKey.ACTION),
)
print("\n" + "-" * (display_len + 10))
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
@@ -88,6 +88,7 @@ class KochLeader(Teleoperator):
return self.bus.is_calibrated
def calibrate(self) -> None:
self.bus.disable_torque()
if self.calibration:
# Calibration file exists, ask user whether to use it or run new calibration
user_input = input(
@@ -98,7 +99,6 @@ class KochLeader(Teleoperator):
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
@@ -16,45 +16,53 @@
from dataclasses import dataclass, field
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.constants import ACTION
from lerobot.processor import ActionProcessorStep, ProcessorStepRegistry
from lerobot.teleoperators.phone.config_phone import PhoneOS
@ProcessorStepRegistry.register("map_phone_action_to_robot_action")
@dataclass
class MapPhoneActionToRobotAction(ActionProcessor):
class MapPhoneActionToRobotAction(ActionProcessorStep):
"""
Map calibrated phone pose (actions) to the inputs for robot actions
Maps calibrated phone pose actions to standardized robot action inputs.
Expected input ACTION keys:
{
"action.phone.enabled": bool,
"action.phone.pos": np.ndarray,
"action.phone.rot": Rotation,
"action.phone.raw_inputs": dict,
}
This processor step acts as a bridge between the phone teleoperator's output
and the robot's expected action format. It remaps the phone's 6-DoF pose
(position and rotation) to the robot's target end-effector pose, applying
necessary axis inversions and swaps. It also interprets platform-specific
button presses to generate a gripper command.
Output ACTION keys:
{
"action.enabled": bool,
"action.ee.{x,y,z,wx,wy,wz}" : float
"action.gripper": float,
}
Attributes:
platform: The operating system of the phone (iOS or Android), used
to determine the correct button mappings for the gripper.
"""
platform: PhoneOS
_enabled_prev: bool = field(default=False, init=False, repr=False)
def action(self, act: dict) -> dict:
"""
Processes the phone action dictionary to create a robot action dictionary.
Args:
act: The input action dictionary from the phone teleoperator.
Returns:
A new action dictionary formatted for the robot controller.
Raises:
ValueError: If 'pos' or 'rot' keys are missing from the input action.
"""
# Pop them from the action
enabled = bool(act.pop("action.phone.enabled", 0))
pos = act.pop("action.phone.pos", None)
rot = act.pop("action.phone.rot", None)
inputs = act.pop("action.phone.raw_inputs", {})
enabled = bool(act.pop(f"{ACTION}.phone.enabled", 0))
pos = act.pop(f"{ACTION}.phone.pos", None)
rot = act.pop(f"{ACTION}.phone.rot", None)
inputs = act.pop(f"{ACTION}.phone.raw_inputs", {})
if pos is None or rot is None:
return act
raise ValueError("pos and rot must be present in action")
rotvec = rot.as_rotvec() # Absolute orientation as rotvec
@@ -69,28 +77,30 @@ class MapPhoneActionToRobotAction(ActionProcessor):
) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
# For some actions we need to invert the axis
act["action.enabled"] = enabled
act["action.target_x"] = -pos[1] if enabled else 0.0
act["action.target_y"] = pos[0] if enabled else 0.0
act["action.target_z"] = pos[2] if enabled else 0.0
act["action.target_wx"] = rotvec[1] if enabled else 0.0
act["action.target_wy"] = rotvec[0] if enabled else 0.0
act["action.target_wz"] = -rotvec[2] if enabled else 0.0
act["action.gripper"] = gripper # Still send gripper action when disabled
act[f"{ACTION}.enabled"] = enabled
act[f"{ACTION}.target_x"] = -pos[1] if enabled else 0.0
act[f"{ACTION}.target_y"] = pos[0] if enabled else 0.0
act[f"{ACTION}.target_z"] = pos[2] if enabled else 0.0
act[f"{ACTION}.target_wx"] = rotvec[1] if enabled else 0.0
act[f"{ACTION}.target_wy"] = rotvec[0] if enabled else 0.0
act[f"{ACTION}.target_wz"] = -rotvec[2] if enabled else 0.0
act[f"{ACTION}.gripper"] = gripper # Still send gripper action when disabled
return act
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop("action.phone.enabled", None)
features.pop("action.phone.pos", None)
features.pop("action.phone.rot", None)
features.pop("action.phone.raw_inputs", None)
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("phone.enabled", None)
features[PipelineFeatureType.ACTION].pop("phone.pos", None)
features[PipelineFeatureType.ACTION].pop("phone.rot", None)
features[PipelineFeatureType.ACTION].pop("phone.raw_inputs", None)
features["action.enabled"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.target_wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.gripper"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
+84 -22
View File
@@ -24,12 +24,12 @@ import time
import hebi
import numpy as np
from scipy.spatial.transform import Rotation
from teleop import Teleop
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.utils.rotation import Rotation
logger = logging.getLogger(__name__)
@@ -101,32 +101,56 @@ class IOSPhone(BasePhone, Teleoperator):
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
)
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
pos, rot = self._wait_for_capture_trigger()
self._calib_pos = pos.copy()
self._calib_rot_inv = rot.inv()
position, rotation = self._wait_for_capture_trigger()
self._calib_pos = position.copy()
self._calib_rot_inv = rotation.inv()
self._enabled = False
print("Calibration done\n")
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
"""
Blocks execution until the calibration trigger is detected from the iOS device.
This method enters a loop, continuously reading the phone's state. It waits for the user to press
and hold the 'B1' button in the HEBI Mobile I/O app. Once B1 is pressed, the loop breaks and
returns the phone's pose at that exact moment.
Returns:
A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
moment the trigger was activated.
"""
while True:
ok, pos, rot, pose = self._read_current_pose()
if not ok:
has_pose, position, rotation, fb_pose = self._read_current_pose()
if not has_pose:
time.sleep(0.01)
continue
io = getattr(pose, "io", None)
b = getattr(io, "b", None) if io is not None else None
b1 = False
if b is not None:
b1 = bool(b.get_int(1))
if b1:
return pos, rot
io = getattr(fb_pose, "io", None)
button_b = getattr(io, "b", None) if io is not None else None
button_b1_pressed = False
if button_b is not None:
button_b1_pressed = bool(button_b.get_int(1))
if button_b1_pressed:
return position, rotation
time.sleep(0.01)
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
"""
Reads the instantaneous 6-DoF pose from the connected iOS device via the HEBI SDK.
This method fetches the latest feedback packet from the HEBI group, extracts the ARKit
position and orientation, and converts them into a standard format. It also applies a
configured camera offset to adjust the pose from the camera's frame to the phone's
physical frame.
Returns:
A tuple containing:
- A boolean indicating if a valid pose was successfully read.
- The 3D position as a NumPy array, or None if not available.
- The orientation as a `Rotation` object, or None if not available.
- The raw HEBI feedback object for accessing other data like button presses.
"""
fbk = self._group.get_next_feedback()
pose = fbk[0]
ar_pos = getattr(pose, "ar_position", None)
@@ -141,13 +165,13 @@ class IOSPhone(BasePhone, Teleoperator):
return True, pos, rot, pose
def get_action(self) -> dict:
ok, raw_pos, raw_rot, pose = self._read_current_pose()
if not ok or not self.is_calibrated:
has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
if not has_pose or not self.is_calibrated:
return {}
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
raw_inputs: dict[str, float | int | bool] = {}
io = getattr(pose, "io", None)
io = getattr(fb_pose, "io", None)
if io is not None:
bank_a, bank_b = io.a, io.b
if bank_a:
@@ -165,11 +189,11 @@ class IOSPhone(BasePhone, Teleoperator):
# Rising edge then re-capture calibration immediately from current raw pose
if enable and not self._enabled:
self._reapply_position_calibration(raw_pos)
self._reapply_position_calibration(raw_position)
# Apply calibration
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
rot_cal = self._calib_rot_inv * raw_rot
pos_cal = self._calib_rot_inv.apply(raw_position - self._calib_pos)
rot_cal = self._calib_rot_inv * raw_rotation
self._enabled = enable
@@ -229,7 +253,18 @@ class AndroidPhone(BasePhone, Teleoperator):
print("Calibration done\n")
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
"""
Blocks execution until the calibration trigger is detected from the Android device.
This method enters a loop, continuously checking the latest message received from the WebXR
session. It waits for the user to touch and move their finger on the screen, which generates
a `move` event. Once this event is detected, the loop breaks and returns the phone's current
pose.
Returns:
A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
moment the trigger was activated.
"""
while True:
with self._android_lock:
msg = self._latest_message or {}
@@ -242,6 +277,20 @@ class AndroidPhone(BasePhone, Teleoperator):
time.sleep(0.01)
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
"""
Reads the latest 6-DoF pose received from the Android device's WebXR session.
This method accesses the most recent pose data stored by the `_android_callback`. It uses a
thread lock to safely read the shared `_latest_pose` variable. The pose, a 4x4 matrix, is
then decomposed into position and rotation, and the configured camera offset is applied.
Returns:
A tuple containing:
- A boolean indicating if a valid pose was available.
- The 3D position as a NumPy array, or None if no pose has been received yet.
- The orientation as a `Rotation` object, or None if no pose has been received.
- The raw 4x4 pose matrix as received from the teleop stream.
"""
with self._android_lock:
if self._latest_pose is None:
return False, None, None, None
@@ -252,6 +301,19 @@ class AndroidPhone(BasePhone, Teleoperator):
return True, pos, rot, pose
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
"""
Callback function to handle incoming data from the Android teleop stream.
This method is executed by the `teleop` package's subscriber thread whenever a new
pose and message are received from the WebXR session on the Android phone. It updates
the internal state (`_latest_pose` and `_latest_message`) with the new data.
A thread lock is used to ensure that these shared variables are updated atomically,
preventing race conditions with the main thread that reads them.
Args:
pose: A 4x4 NumPy array representing the phone's transformation matrix.
message: A dictionary containing additional data, such as button presses or touch events.
"""
with self._android_lock:
self._latest_pose = pose
self._latest_message = message
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
from .reachy2_teleoperator import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Teleoperator,
)
@@ -0,0 +1,51 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("reachy2_teleoperator")
@dataclass
class Reachy2TeleoperatorConfig(TeleoperatorConfig):
# IP address of the Reachy 2 robot used as teleoperator
ip_address: str | None = "localhost"
# Whether to use the present position of the joints as actions
# if False, the goal position of the joints will be used
use_present_position: bool = False
# Which parts of the robot to use
with_mobile_base: bool = True
with_l_arm: bool = True
with_r_arm: bool = True
with_neck: bool = True
with_antennas: bool = True
def __post_init__(self):
if not (
self.with_mobile_base
or self.with_l_arm
or self.with_r_arm
or self.with_neck
or self.with_antennas
):
raise ValueError(
"No Reachy2Teleoperator part used.\n"
"At least one part of the robot must be set to True "
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
)
@@ -0,0 +1,164 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from reachy2_sdk import ReachySDK
from ..teleoperator import Teleoperator
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
logger = logging.getLogger(__name__)
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_NECK_JOINTS = {
"neck_yaw.pos": "head.neck.yaw",
"neck_pitch.pos": "head.neck.pitch",
"neck_roll.pos": "head.neck.roll",
}
REACHY2_ANTENNAS_JOINTS = {
"l_antenna.pos": "head.l_antenna",
"r_antenna.pos": "head.r_antenna",
}
REACHY2_R_ARM_JOINTS = {
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
"r_wrist_roll.pos": "r_arm.wrist.roll",
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
"r_gripper.pos": "r_arm.gripper",
}
REACHY2_L_ARM_JOINTS = {
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
"l_wrist_roll.pos": "l_arm.wrist.roll",
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
"l_gripper.pos": "l_arm.gripper",
}
REACHY2_VEL = {
"mobile_base.vx": "vx",
"mobile_base.vy": "vy",
"mobile_base.vtheta": "vtheta",
}
class Reachy2Teleoperator(Teleoperator):
"""
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
"""
config_class = Reachy2TeleoperatorConfig
name = "reachy2_specific"
def __init__(self, config: Reachy2TeleoperatorConfig):
super().__init__(config)
self.config = config
self.reachy: None | ReachySDK = None
self.joints_dict: dict[str, str] = self._generate_joints_dict()
def _generate_joints_dict(self) -> dict[str, str]:
joints = {}
if self.config.with_neck:
joints.update(REACHY2_NECK_JOINTS)
if self.config.with_l_arm:
joints.update(REACHY2_L_ARM_JOINTS)
if self.config.with_r_arm:
joints.update(REACHY2_R_ARM_JOINTS)
if self.config.with_antennas:
joints.update(REACHY2_ANTENNAS_JOINTS)
return joints
@property
def action_features(self) -> dict[str, type]:
if self.config.with_mobile_base:
return {
**dict.fromkeys(
self.joints_dict.keys(),
float,
),
**dict.fromkeys(
REACHY2_VEL.keys(),
float,
),
}
else:
return dict.fromkeys(self.joints_dict.keys(), float)
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False
def connect(self, calibrate: bool = True) -> None:
self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected:
raise ConnectionError()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
def configure(self) -> None:
pass
def get_action(self) -> dict[str, float]:
start = time.perf_counter()
if self.reachy and self.is_connected:
if self.config.use_present_position:
joint_action = {
k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()
}
else:
joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base:
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return joint_action
if self.config.use_present_position:
vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
else:
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return {**joint_action, **vel_action}
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError
def disconnect(self) -> None:
if self.reachy and self.is_connected:
self.reachy.disconnect()
+4
View File
@@ -77,5 +77,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .bi_so100_leader import BiSO100Leader
return BiSO100Leader(config)
elif config.type == "reachy2_teleoperator":
from .reachy2_teleoperator import Reachy2Teleoperator
return Reachy2Teleoperator(config)
else:
raise ValueError(config.type)
+90 -4
View File
@@ -31,11 +31,25 @@ from termcolor import colored
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor, TransitionKey
from lerobot.processor import PolicyProcessorPipeline, TransitionKey
from lerobot.robots import Robot
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
"""
Logs performance metrics for a single step of the robot control loop.
This function formats and prints a single line of log information, including episode/frame counters,
total loop time (dt), and detailed timings for various robot and camera operations. It can also
highlight performance drops in yellow if the actual FPS is lower than the target FPS.
Args:
robot: The `Robot` instance, used to access its internal logs for detailed timings.
dt_s: The total duration of the control loop step in seconds.
episode_index: The index of the current episode.
frame_index: The index of the current frame within the episode.
fps: The target frames per second, used to check for performance degradation.
"""
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
@@ -81,7 +95,16 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
@cache
def is_headless():
"""Detects if python is running without a monitor."""
"""
Detects if the Python script is running in a headless environment (e.g., without a display).
This function attempts to import `pynput`, a library that requires a graphical environment.
If the import fails, it assumes the environment is headless. The result is cached to avoid
re-running the check.
Returns:
True if the environment is determined to be headless, False otherwise.
"""
try:
import pynput # noqa
@@ -102,12 +125,35 @@ def predict_action(
observation: dict[str, np.ndarray],
policy: PreTrainedPolicy,
device: torch.device,
preprocessor: RobotProcessor,
postprocessor: RobotProcessor,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
use_amp: bool,
task: str | None = None,
robot_type: str | None = None,
):
"""
Performs a single-step inference to predict a robot action from an observation.
This function encapsulates the full inference pipeline:
1. Prepares the observation by converting it to PyTorch tensors and adding a batch dimension.
2. Runs the preprocessor pipeline on the observation.
3. Feeds the processed observation to the policy to get a raw action.
4. Runs the postprocessor pipeline on the raw action.
5. Formats the final action by removing the batch dimension and moving it to the CPU.
Args:
observation: A dictionary of NumPy arrays representing the robot's current observation.
policy: The `PreTrainedPolicy` model to use for action prediction.
device: The `torch.device` (e.g., 'cuda' or 'cpu') to run inference on.
preprocessor: The `PolicyProcessorPipeline` for preprocessing observations.
postprocessor: The `PolicyProcessorPipeline` for postprocessing actions.
use_amp: A boolean to enable/disable Automatic Mixed Precision for CUDA inference.
task: An optional string identifier for the task.
robot_type: An optional string identifier for the robot type.
Returns:
A `torch.Tensor` containing the predicted action, ready for the robot.
"""
observation = copy(observation)
with (
torch.inference_mode(),
@@ -143,6 +189,18 @@ def predict_action(
def init_keyboard_listener():
"""
Initializes a non-blocking keyboard listener for real-time user interaction.
This function sets up a listener for specific keys (right arrow, left arrow, escape) to control
the program flow during execution, such as stopping recording or exiting loops. It gracefully
handles headless environments where keyboard listening is not possible.
Returns:
A tuple containing:
- The `pynput.keyboard.Listener` instance, or `None` if in a headless environment.
- A dictionary of event flags (e.g., `exit_early`) that are set by key presses.
"""
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
@@ -184,6 +242,19 @@ def init_keyboard_listener():
def sanity_check_dataset_name(repo_id, policy_cfg):
"""
Validates the dataset repository name against the presence of a policy configuration.
This function enforces a naming convention: a dataset repository ID should start with "eval_"
if and only if a policy configuration is provided for evaluation purposes.
Args:
repo_id: The Hugging Face Hub repository ID of the dataset.
policy_cfg: The configuration object for the policy, or `None`.
Raises:
ValueError: If the naming convention is violated.
"""
_, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy
# or repo_id starts with "eval_" and there is a policy
@@ -204,6 +275,21 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, features: dict
) -> None:
"""
Checks if a dataset's metadata is compatible with the current robot and recording setup.
This function compares key metadata fields (`robot_type`, `fps`, and `features`) from the
dataset against the current configuration to ensure that appended data will be consistent.
Args:
dataset: The `LeRobotDataset` instance to check.
robot: The `Robot` instance representing the current hardware setup.
fps: The current recording frequency (frames per second).
features: The dictionary of features for the current recording session.
Raises:
ValueError: If any of the checked metadata fields do not match.
"""
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
+174
View File
@@ -0,0 +1,174 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom rotation utilities to replace scipy.spatial.transform.Rotation."""
import numpy as np
class Rotation:
"""
Custom rotation class that provides a subset of scipy.spatial.transform.Rotation functionality.
Supports conversions between rotation vectors, rotation matrices, and quaternions.
"""
def __init__(self, quat: np.ndarray) -> None:
"""Initialize rotation from quaternion [x, y, z, w]."""
self._quat = np.asarray(quat, dtype=float)
# Normalize quaternion
norm = np.linalg.norm(self._quat)
if norm > 0:
self._quat = self._quat / norm
@classmethod
def from_rotvec(cls, rotvec: np.ndarray) -> "Rotation":
"""
Create rotation from rotation vector using Rodrigues' formula.
Args:
rotvec: Rotation vector [x, y, z] where magnitude is angle in radians
Returns:
Rotation instance
"""
rotvec = np.asarray(rotvec, dtype=float)
angle = np.linalg.norm(rotvec)
if angle < 1e-8:
# For very small angles, use identity quaternion
quat = np.array([0.0, 0.0, 0.0, 1.0])
else:
axis = rotvec / angle
half_angle = angle / 2.0
sin_half = np.sin(half_angle)
cos_half = np.cos(half_angle)
# Quaternion [x, y, z, w]
quat = np.array([axis[0] * sin_half, axis[1] * sin_half, axis[2] * sin_half, cos_half])
return cls(quat)
@classmethod
def from_matrix(cls, matrix: np.ndarray) -> "Rotation":
"""
Create rotation from 3x3 rotation matrix.
Args:
matrix: 3x3 rotation matrix
Returns:
Rotation instance
"""
matrix = np.asarray(matrix, dtype=float)
# Shepherd's method for converting rotation matrix to quaternion
trace = np.trace(matrix)
if trace > 0:
s = np.sqrt(trace + 1.0) * 2 # s = 4 * qw
qw = 0.25 * s
qx = (matrix[2, 1] - matrix[1, 2]) / s
qy = (matrix[0, 2] - matrix[2, 0]) / s
qz = (matrix[1, 0] - matrix[0, 1]) / s
elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]:
s = np.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 # s = 4 * qx
qw = (matrix[2, 1] - matrix[1, 2]) / s
qx = 0.25 * s
qy = (matrix[0, 1] + matrix[1, 0]) / s
qz = (matrix[0, 2] + matrix[2, 0]) / s
elif matrix[1, 1] > matrix[2, 2]:
s = np.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 # s = 4 * qy
qw = (matrix[0, 2] - matrix[2, 0]) / s
qx = (matrix[0, 1] + matrix[1, 0]) / s
qy = 0.25 * s
qz = (matrix[1, 2] + matrix[2, 1]) / s
else:
s = np.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 # s = 4 * qz
qw = (matrix[1, 0] - matrix[0, 1]) / s
qx = (matrix[0, 2] + matrix[2, 0]) / s
qy = (matrix[1, 2] + matrix[2, 1]) / s
qz = 0.25 * s
quat = np.array([qx, qy, qz, qw])
return cls(quat)
@classmethod
def from_quat(cls, quat: np.ndarray) -> "Rotation":
"""
Create rotation from quaternion.
Args:
quat: Quaternion [x, y, z, w] or [w, x, y, z] (specify convention in docstring)
This implementation expects [x, y, z, w] format
Returns:
Rotation instance
"""
return cls(quat)
def as_matrix(self) -> np.ndarray:
"""
Convert rotation to 3x3 rotation matrix.
Returns:
3x3 rotation matrix
"""
qx, qy, qz, qw = self._quat
# Compute rotation matrix from quaternion
return np.array(
[
[1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)],
[2 * (qx * qy + qz * qw), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qx * qw)],
[2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx * qx + qy * qy)],
],
dtype=float,
)
def as_rotvec(self) -> np.ndarray:
"""
Convert rotation to rotation vector.
Returns:
Rotation vector [x, y, z] where magnitude is angle in radians
"""
qx, qy, qz, qw = self._quat
# Ensure qw is positive for unique representation
if qw < 0:
qx, qy, qz, qw = -qx, -qy, -qz, -qw
# Compute angle and axis
angle = 2.0 * np.arccos(np.clip(abs(qw), 0.0, 1.0))
sin_half_angle = np.sqrt(1.0 - qw * qw)
if sin_half_angle < 1e-8:
# For very small angles, use linearization: rotvec ≈ 2 * [qx, qy, qz]
return 2.0 * np.array([qx, qy, qz])
# Extract axis and scale by angle
axis = np.array([qx, qy, qz]) / sin_half_angle
return angle * axis
def as_quat(self) -> np.ndarray:
"""
Get quaternion representation.
Returns:
Quaternion [x, y, z, w]
"""
return self._quat.copy()
+3 -3
View File
@@ -32,7 +32,7 @@ from lerobot.datasets.utils import load_json, write_json
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor.pipeline import RobotProcessor
from lerobot.processor import PolicyProcessorPipeline
from lerobot.utils.random_utils import load_rng_state, save_rng_state
@@ -75,8 +75,8 @@ def save_checkpoint(
policy: PreTrainedPolicy,
optimizer: Optimizer,
scheduler: LRScheduler | None = None,
preprocessor: RobotProcessor | None = None,
postprocessor: RobotProcessor | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
) -> None:
"""This function creates the following directory structure:
+49 -69
View File
@@ -19,8 +19,6 @@ from typing import Any
import numpy as np
import rerun as rr
from lerobot.processor.pipeline import EnvTransition, TransitionKey
def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
"""Initializes the Rerun SDK for visualizing the control loop."""
@@ -33,85 +31,67 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
def _is_scalar(x):
return (
isinstance(x, numbers.Real)
isinstance(x, float)
or isinstance(x, numbers.Real)
or isinstance(x, (np.integer, np.floating))
or (isinstance(x, np.ndarray) and x.ndim == 0)
)
def log_rerun_data(
data: list[dict[str | Any] | EnvTransition] | dict[str | Any] | EnvTransition | None = None,
*,
observation: dict[str, Any] | None = None,
action: dict[str, Any] | None = None,
) -> None:
items = data if isinstance(data, list) else ([data] if data is not None else [])
"""
Logs observation and action data to Rerun for real-time visualization.
obs = {} if observation is None else dict(observation)
act = {} if action is None else dict(action)
This function iterates through the provided observation and action dictionaries and sends their contents
to the Rerun viewer. It handles different data types appropriately:
- Scalar values (floats, ints) are logged as `rr.Scalar`.
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
from CHW to HWC format and logged as `rr.Image`.
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
- Other multi-dimensional arrays are flattened and logged as individual scalars.
for idx, item in enumerate(items):
if not isinstance(item, dict):
continue
Keys are automatically namespaced with "observation." or "action." if not already present.
if any(isinstance(k, TransitionKey) for k in item.keys()):
o = item.get(TransitionKey.OBSERVATION) or {}
a = item.get(TransitionKey.ACTION) or {}
if isinstance(o, dict):
obs.update(o)
if isinstance(a, dict):
act.update(a)
continue
Args:
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
"""
if observation:
for k, v in observation.items():
if v is None:
continue
key = k if str(k).startswith("observation.") else f"observation.{k}"
keys = list(item.keys())
has_obs = any(str(k).startswith("observation.") for k in keys)
has_act = any(str(k).startswith("action.") for k in keys)
if _is_scalar(v):
rr.log(key, rr.Scalar(float(v)))
elif isinstance(v, np.ndarray):
arr = v
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 1:
for i, vi in enumerate(arr):
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
else:
rr.log(key, rr.Image(arr), static=True)
if has_obs or has_act:
if has_obs:
obs.update(item)
if has_act:
act.update(item)
else:
# No prefixes: assume first is observation, second is action, others are observation
if idx == 0:
obs.update(item)
elif idx == 1:
act.update(item)
else:
obs.update(item)
if action:
for k, v in action.items():
if v is None:
continue
key = k if str(k).startswith("action.") else f"action.{k}"
for k, v in obs.items():
if v is None:
continue
key = k if str(k).startswith("observation.") else f"observation.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalar(float(v)))
elif isinstance(v, np.ndarray):
arr = v
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 1:
for i, vi in enumerate(arr):
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
else:
rr.log(key, rr.Image(arr), static=True)
for k, v in act.items():
if v is None:
continue
key = k if str(k).startswith("action.") else f"action.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalar(float(v)))
elif isinstance(v, np.ndarray):
if v.ndim == 1:
for i, vi in enumerate(v):
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
else:
# Fall back to flattening higher-dimensional arrays
flat = v.flatten()
for i, vi in enumerate(flat):
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
if _is_scalar(v):
rr.log(key, rr.Scalar(float(v)))
elif isinstance(v, np.ndarray):
if v.ndim == 1:
for i, vi in enumerate(v):
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
else:
# Fall back to flattening higher-dimensional arrays
flat = v.flatten()
for i, vi in enumerate(flat):
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
+177
View File
@@ -0,0 +1,177 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig
from lerobot.errors import DeviceNotConnectedError
PARAMS = [
("teleop", "left"),
("teleop", "right"),
("depth", "rgb"),
# ("depth", "depth"), # Depth camera is not available yet
]
def _make_cam_manager_mock():
c = MagicMock(name="CameraManagerMock")
teleop = MagicMock(name="TeleopCam")
teleop.width = 640
teleop.height = 480
teleop.get_frame = MagicMock(
side_effect=lambda *_, **__: (
np.zeros((480, 640, 3), dtype=np.uint8),
time.time(),
)
)
depth = MagicMock(name="DepthCam")
depth.width = 640
depth.height = 480
depth.get_frame = MagicMock(
side_effect=lambda *_, **__: (
np.zeros((480, 640, 3), dtype=np.uint8),
time.time(),
)
)
c.is_connected.return_value = True
c.teleop = teleop
c.depth = depth
def _connect():
c.teleop = teleop
c.depth = depth
c.is_connected.return_value = True
def _disconnect():
c.teleop = None
c.depth = None
c.is_connected.return_value = False
c.connect = MagicMock(side_effect=_connect)
c.disconnect = MagicMock(side_effect=_disconnect)
# Mock methods
c.initialize_cameras = MagicMock()
return c
@pytest.fixture(
params=PARAMS,
# ids=["teleop-left", "teleop-right", "torso-rgb", "torso-depth"],
ids=["teleop-left", "teleop-right", "torso-rgb"],
)
def camera(request):
name, image_type = request.param
with (
patch(
"lerobot.cameras.reachy2_camera.reachy2_camera.CameraManager",
side_effect=lambda *a, **k: _make_cam_manager_mock(),
),
):
config = Reachy2CameraConfig(name=name, image_type=image_type)
cam = Reachy2Camera(config)
yield cam
if cam.is_connected:
cam.disconnect()
def test_connect(camera):
camera.connect()
assert camera.is_connected
camera.cam_manager.initialize_cameras.assert_called_once()
def test_read(camera):
camera.connect()
img = camera.read()
if camera.config.name == "teleop":
camera.cam_manager.teleop.get_frame.assert_called_once()
elif camera.config.name == "depth":
camera.cam_manager.depth.get_frame.assert_called_once()
assert isinstance(img, np.ndarray)
assert img.shape == (480, 640, 3)
def test_disconnect(camera):
camera.connect()
camera.disconnect()
assert not camera.is_connected
def test_async_read(camera):
camera.connect()
try:
img = camera.async_read()
assert camera.thread is not None
assert camera.thread.is_alive()
assert isinstance(img, np.ndarray)
finally:
if camera.is_connected:
camera.disconnect()
def test_async_read_timeout(camera):
camera.connect()
try:
with pytest.raises(TimeoutError):
camera.async_read(timeout_ms=0)
finally:
if camera.is_connected:
camera.disconnect()
def test_read_before_connect(camera):
with pytest.raises(DeviceNotConnectedError):
_ = camera.read()
def test_disconnect_before_connect(camera):
with pytest.raises(DeviceNotConnectedError):
camera.disconnect()
def test_async_read_before_connect(camera):
with pytest.raises(DeviceNotConnectedError):
_ = camera.async_read()
def test_wrong_camera_name():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="wrong-name", image_type="left")
def test_wrong_image_type():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="teleop", image_type="rgb")
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="depth", image_type="left")
def test_wrong_color_mode():
with pytest.raises(ValueError):
_ = Reachy2CameraConfig(name="teleop", image_type="left", color_mode="wrong-color")
+7 -4
View File
@@ -19,7 +19,7 @@ import traceback
import pytest
from serial import SerialException
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from tests.utils import DEVICE
# Import fixture modules as plugins
@@ -28,6 +28,7 @@ pytest_plugins = [
"tests.fixtures.files",
"tests.fixtures.hub",
"tests.fixtures.optimizers",
"tests.plugins.reachy2_sdk",
]
@@ -82,7 +83,9 @@ def policy_feature_factory():
return _pf
def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None:
def assert_contract_is_typed(features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> None:
assert isinstance(features, dict)
assert all(isinstance(k, str) for k in features.keys())
assert all(isinstance(v, PolicyFeature) for v in features.values())
assert all(isinstance(k, PipelineFeatureType) for k in features.keys())
assert all(isinstance(v, dict) for v in features.values())
assert all(all(isinstance(nk, str) for nk in v.keys()) for v in features.values())
assert all(all(isinstance(nv, PolicyFeature) for nv in v.values()) for v in features.values())
+6 -6
View File
@@ -20,7 +20,7 @@ from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch, merge_features
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
def test_default_parameters():
@@ -72,7 +72,7 @@ def test_merge_simple_vectors():
}
}
out = merge_features(g1, g2)
out = combine_feature_dicts(g1, g2)
assert "action" in out
assert out["action"]["dtype"] == "float32"
@@ -87,7 +87,7 @@ def test_merge_multiple_groups_order_and_dedup():
g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
out = merge_features(g1, g2, g3)
out = combine_feature_dicts(g1, g2, g3)
assert out["action"]["names"] == ["a", "b", "c", "d"]
assert out["action"]["shape"] == (4,)
@@ -110,7 +110,7 @@ def test_non_vector_last_wins_for_images():
}
}
out = merge_features(g1, g2)
out = combine_feature_dicts(g1, g2)
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
assert out["observation.images.front"]["dtype"] == "image"
@@ -120,13 +120,13 @@ def test_dtype_mismatch_raises():
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
_ = merge_features(g1, g2)
_ = combine_feature_dicts(g1, g2)
def test_non_dict_passthrough_last_wins():
g1 = {"misc": 123}
g2 = {"misc": 456}
out = merge_features(g1, g2)
out = combine_feature_dicts(g1, g2)
# For non-dict entries the last one wins
assert out["misc"] == 456
+30
View File
@@ -0,0 +1,30 @@
import sys
import types
from unittest.mock import MagicMock
def _install_reachy2_sdk_stub():
sdk = types.ModuleType("reachy2_sdk")
sdk.__path__ = []
sdk.ReachySDK = MagicMock(name="ReachySDK")
media = types.ModuleType("reachy2_sdk.media")
media.__path__ = []
camera = types.ModuleType("reachy2_sdk.media.camera")
camera.CameraView = MagicMock(name="CameraView")
camera_manager = types.ModuleType("reachy2_sdk.media.camera_manager")
camera_manager.CameraManager = MagicMock(name="CameraManager")
sdk.media = media
media.camera = camera
media.camera_manager = camera_manager
# Register in sys.modules
sys.modules.setdefault("reachy2_sdk", sdk)
sys.modules.setdefault("reachy2_sdk.media", media)
sys.modules.setdefault("reachy2_sdk.media.camera", camera)
sys.modules.setdefault("reachy2_sdk.media.camera_manager", camera_manager)
def pytest_sessionstart(session):
_install_reachy2_sdk_stub()
+86 -19
View File
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.processor_act import make_act_pre_post_processors
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
DeviceProcessorStep,
NormalizerProcessorStep,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.pipeline import TransitionKey
def create_transition(observation=None, action=None, **kwargs):
@@ -81,20 +81,20 @@ def test_make_act_processor_basic():
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessor)
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
assert isinstance(preprocessor.steps[3], DeviceProcessor)
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], DeviceProcessor)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
def test_act_processor_normalization():
@@ -250,7 +250,7 @@ def test_act_processor_save_and_load():
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -303,11 +303,22 @@ def test_act_processor_mixed_precision():
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessor with one that uses float16
# Replace DeviceProcessorStep with one that uses float16
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessor):
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
@@ -353,3 +364,59 @@ def test_act_processor_batch_consistency():
processed_batched = preprocessor(transition_batched)
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
assert processed_batched[TransitionKey.ACTION].shape[0] == 8
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, _ = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Verify initial normalizer configuration
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32
action = torch.randn(4, dtype=torch.float32)
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16
+18 -22
View File
@@ -1,11 +1,7 @@
import torch
from lerobot.processor.pipeline import (
RobotProcessor,
TransitionKey,
_default_batch_to_transition,
_default_transition_to_batch,
)
from lerobot.processor import DataProcessorPipeline, TransitionKey
from lerobot.processor.converters import batch_to_transition, transition_to_batch
def _dummy_batch():
@@ -24,7 +20,7 @@ def _dummy_batch():
def test_observation_grouping_roundtrip():
"""Test that observation.* keys are properly grouped and ungrouped."""
proc = RobotProcessor([])
proc = DataProcessorPipeline([])
batch_in = _dummy_batch()
batch_out = proc(batch_in)
@@ -48,7 +44,7 @@ def test_observation_grouping_roundtrip():
def test_batch_to_transition_observation_grouping():
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
"""Test that batch_to_transition correctly groups observation.* keys."""
batch = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
@@ -60,7 +56,7 @@ def test_batch_to_transition_observation_grouping():
"info": {"episode": 42},
}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# Check observation is a dict with all observation.* keys
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
@@ -87,7 +83,7 @@ def test_batch_to_transition_observation_grouping():
def test_transition_to_batch_observation_flattening():
"""Test that _default_transition_to_batch correctly flattens observation dict."""
"""Test that transition_to_batch correctly flattens observation dict."""
observation_dict = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128),
@@ -104,7 +100,7 @@ def test_transition_to_batch_observation_flattening():
TransitionKey.COMPLEMENTARY_DATA: {},
}
batch = _default_transition_to_batch(transition)
batch = transition_to_batch(transition)
# Check that observation.* keys are flattened back to batch
assert "observation.image.top" in batch
@@ -134,7 +130,7 @@ def test_no_observation_keys():
"info": {"test": "no_obs"},
}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# Observation should be None when no observation.* keys
assert transition[TransitionKey.OBSERVATION] is None
@@ -147,7 +143,7 @@ def test_no_observation_keys():
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
# Round trip should work
reconstructed_batch = _default_transition_to_batch(transition)
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["action"] == "action_data"
assert reconstructed_batch["next.reward"] == 2.0
assert not reconstructed_batch["next.done"]
@@ -159,7 +155,7 @@ def test_minimal_batch():
"""Test with minimal batch containing only observation.* and action."""
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# Check observation
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
@@ -173,7 +169,7 @@ def test_minimal_batch():
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["observation.state"] == "minimal_state"
assert reconstructed_batch["action"] == "minimal_action"
assert reconstructed_batch["next.reward"] == 0.0
@@ -186,7 +182,7 @@ def test_empty_batch():
"""Test behavior with empty batch."""
batch = {}
transition = _default_batch_to_transition(batch)
transition = batch_to_transition(batch)
# All fields should have defaults
assert transition[TransitionKey.OBSERVATION] is None
@@ -198,7 +194,7 @@ def test_empty_batch():
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["action"] is None
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
@@ -219,8 +215,8 @@ def test_complex_nested_observation():
"info": {"episode_length": 200, "success": True},
}
transition = _default_batch_to_transition(batch)
reconstructed_batch = _default_transition_to_batch(transition)
transition = batch_to_transition(batch)
reconstructed_batch = transition_to_batch(transition)
# Check that all observation keys are preserved
original_obs_keys = {k for k in batch if k.startswith("observation.")}
@@ -254,7 +250,7 @@ def test_custom_converter():
def to_tr(batch):
# Custom converter that modifies the reward
tr = _default_batch_to_transition(batch)
tr = batch_to_transition(batch)
# Double the reward
reward = tr.get(TransitionKey.REWARD, 0.0)
new_tr = tr.copy()
@@ -262,10 +258,10 @@ def test_custom_converter():
return new_tr
def to_batch(tr):
batch = _default_transition_to_batch(tr)
batch = transition_to_batch(tr)
return batch
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch)
batch = {
"observation.state": torch.randn(1, 4),
+66 -63
View File
@@ -22,9 +22,12 @@ import pytest
import torch
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import ProcessorStepRegistry, RobotProcessor
from lerobot.processor.batch_processor import ToBatchProcessor
from lerobot.processor.pipeline import TransitionKey
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
ProcessorStepRegistry,
TransitionKey,
)
def create_transition(
@@ -44,7 +47,7 @@ def create_transition(
def test_state_1d_to_2d():
"""Test that 1D state tensors get unsqueezed to 2D."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Test observation.state
state_1d = torch.randn(7)
@@ -60,7 +63,7 @@ def test_state_1d_to_2d():
def test_env_state_1d_to_2d():
"""Test that 1D environment state tensors get unsqueezed to 2D."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Test observation.environment_state
env_state_1d = torch.randn(10)
@@ -76,7 +79,7 @@ def test_env_state_1d_to_2d():
def test_image_3d_to_4d():
"""Test that 3D image tensors get unsqueezed to 4D."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Test observation.image
image_3d = torch.randn(224, 224, 3)
@@ -92,7 +95,7 @@ def test_image_3d_to_4d():
def test_multiple_images_3d_to_4d():
"""Test that 3D image tensors in observation.images.* get unsqueezed to 4D."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Test observation.images.camera1 and observation.images.camera2
image1_3d = torch.randn(64, 64, 3)
@@ -117,7 +120,7 @@ def test_multiple_images_3d_to_4d():
def test_already_batched_tensors_unchanged():
"""Test that already batched tensors remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create already batched tensors
state_2d = torch.randn(1, 7)
@@ -143,7 +146,7 @@ def test_already_batched_tensors_unchanged():
def test_higher_dimensional_tensors_unchanged():
"""Test that tensors with more dimensions than expected remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create tensors with more dimensions
state_3d = torch.randn(2, 7, 5) # More than 1D
@@ -166,7 +169,7 @@ def test_higher_dimensional_tensors_unchanged():
def test_non_tensor_values_unchanged():
"""Test that non-tensor values in observations remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
observation = {
OBS_STATE: [1, 2, 3], # List, not tensor
@@ -189,7 +192,7 @@ def test_non_tensor_values_unchanged():
def test_none_observation():
"""Test processor handles None observation gracefully."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
transition = create_transition(observation=None)
result = processor(transition)
@@ -199,7 +202,7 @@ def test_none_observation():
def test_empty_observation():
"""Test processor handles empty observation dict."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
observation = {}
transition = create_transition(observation=observation)
@@ -211,7 +214,7 @@ def test_empty_observation():
def test_mixed_observation():
"""Test processor with mixed observation containing various types and dimensions."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
state_1d = torch.randn(5)
env_state_2d = torch.randn(1, 8) # Already batched
@@ -243,9 +246,9 @@ def test_mixed_observation():
def test_integration_with_robot_processor():
"""Test ToBatchProcessor integration with RobotProcessor."""
to_batch_processor = ToBatchProcessor()
pipeline = RobotProcessor([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)
"""Test AddBatchDimensionProcessorStep integration with RobotProcessor."""
to_batch_processor = AddBatchDimensionProcessorStep()
pipeline = DataProcessorPipeline([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)
# Create unbatched observation
observation = {
@@ -263,7 +266,7 @@ def test_integration_with_robot_processor():
def test_serialization_methods():
"""Test get_config, state_dict, load_state_dict, and reset methods."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Test get_config
config = processor.get_config()
@@ -283,9 +286,9 @@ def test_serialization_methods():
def test_save_and_load_pretrained():
"""Test saving and loading ToBatchProcessor with RobotProcessor."""
processor = ToBatchProcessor()
pipeline = RobotProcessor(
"""Test saving and loading AddBatchDimensionProcessorStep with RobotProcessor."""
processor = AddBatchDimensionProcessorStep()
pipeline = DataProcessorPipeline(
[processor], name="BatchPipeline", to_transition=lambda x: x, to_output=lambda x: x
)
@@ -298,13 +301,13 @@ def test_save_and_load_pretrained():
assert config_path.exists()
# Load pipeline
loaded_pipeline = RobotProcessor.from_pretrained(
loaded_pipeline = DataProcessorPipeline.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
assert loaded_pipeline.name == "BatchPipeline"
assert len(loaded_pipeline) == 1
assert isinstance(loaded_pipeline.steps[0], ToBatchProcessor)
assert isinstance(loaded_pipeline.steps[0], AddBatchDimensionProcessorStep)
# Test functionality of loaded processor
observation = {OBS_STATE: torch.randn(5)}
@@ -315,10 +318,10 @@ def test_save_and_load_pretrained():
def test_registry_functionality():
"""Test that ToBatchProcessor is properly registered."""
"""Test that AddBatchDimensionProcessorStep is properly registered."""
# Check that the processor is registered
registered_class = ProcessorStepRegistry.get("to_batch_processor")
assert registered_class is ToBatchProcessor
assert registered_class is AddBatchDimensionProcessorStep
# Check that it's in the list of registered processors
assert "to_batch_processor" in ProcessorStepRegistry.list()
@@ -326,12 +329,12 @@ def test_registry_functionality():
def test_registry_based_save_load():
"""Test saving and loading using registry name."""
processor = ToBatchProcessor()
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
processor = AddBatchDimensionProcessorStep()
pipeline = DataProcessorPipeline([processor], to_transition=lambda x: x, to_output=lambda x: x)
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
loaded_pipeline = RobotProcessor.from_pretrained(
loaded_pipeline = DataProcessorPipeline.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -352,7 +355,7 @@ def test_registry_based_save_load():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_device_compatibility():
"""Test processor works with tensors on different devices."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create tensors on GPU
state_1d = torch.randn(7, device="cuda")
@@ -376,7 +379,7 @@ def test_device_compatibility():
def test_processor_preserves_other_transition_keys():
"""Test that processor only modifies observation and preserves other transition keys."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
action = torch.randn(5)
reward = 1.5
@@ -413,7 +416,7 @@ def test_processor_preserves_other_transition_keys():
def test_edge_case_zero_dimensional_tensors():
"""Test processor handles 0D tensors (scalars) correctly."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# 0D tensors should not be modified
scalar_tensor = torch.tensor(42.0)
@@ -435,7 +438,7 @@ def test_edge_case_zero_dimensional_tensors():
# Action-specific tests
def test_action_1d_to_2d():
"""Test that 1D action tensors get batch dimension added."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create 1D action tensor
action_1d = torch.randn(4)
@@ -450,7 +453,7 @@ def test_action_1d_to_2d():
def test_action_already_batched():
"""Test that already batched action tensors remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Test various batch sizes
action_batched_1 = torch.randn(1, 4)
@@ -469,7 +472,7 @@ def test_action_already_batched():
def test_action_higher_dimensional():
"""Test that higher dimensional action tensors remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# 3D action tensor (e.g., sequence of actions)
action_3d = torch.randn(2, 4, 3)
@@ -486,7 +489,7 @@ def test_action_higher_dimensional():
def test_action_scalar_tensor():
"""Test that scalar (0D) action tensors remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
action_scalar = torch.tensor(1.5)
transition = create_transition(action=action_scalar)
@@ -499,7 +502,7 @@ def test_action_scalar_tensor():
def test_action_non_tensor():
"""Test that non-tensor actions remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# List action
action_list = [0.1, 0.2, 0.3, 0.4]
@@ -528,7 +531,7 @@ def test_action_non_tensor():
def test_action_none():
"""Test that None action is handled correctly."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
transition = create_transition(action=None)
result = processor(transition)
@@ -537,7 +540,7 @@ def test_action_none():
def test_action_with_observation():
"""Test action processing together with observation processing."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Both need batching
observation = {
@@ -557,7 +560,7 @@ def test_action_with_observation():
def test_action_different_sizes():
"""Test action processing with various action dimensions."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Different action sizes (robot with different DOF)
action_sizes = [1, 2, 4, 7, 10, 20]
@@ -574,7 +577,7 @@ def test_action_different_sizes():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_action_device_compatibility():
"""Test action processing on different devices."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# CUDA action
action_cuda = torch.randn(4, device="cuda")
@@ -595,7 +598,7 @@ def test_action_device_compatibility():
def test_action_dtype_preservation():
"""Test that action dtype is preserved during processing."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Different dtypes
dtypes = [torch.float32, torch.float64, torch.int32, torch.int64]
@@ -611,7 +614,7 @@ def test_action_dtype_preservation():
def test_empty_action_tensor():
"""Test handling of empty action tensors."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Empty 1D tensor
action_empty = torch.tensor([])
@@ -633,7 +636,7 @@ def test_empty_action_tensor():
# Task-specific tests
def test_task_string_to_list():
"""Test that string tasks get wrapped in lists to add batch dimension."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create complementary data with string task
complementary_data = {"task": "pick_cube"}
@@ -650,7 +653,7 @@ def test_task_string_to_list():
def test_task_string_validation():
"""Test that only string and list of strings are valid task values."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Valid string task - should be converted to list
complementary_data = {"task": "valid_task"}
@@ -669,7 +672,7 @@ def test_task_string_validation():
def test_task_list_of_strings():
"""Test that lists of strings remain unchanged (already batched)."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Test various list of strings
test_lists = [
@@ -695,7 +698,7 @@ def test_task_list_of_strings():
def test_complementary_data_none():
"""Test processor handles None complementary_data gracefully."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
transition = create_transition(complementary_data=None)
result = processor(transition)
@@ -705,7 +708,7 @@ def test_complementary_data_none():
def test_complementary_data_empty():
"""Test processor handles empty complementary_data dict."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
complementary_data = {}
transition = create_transition(complementary_data=complementary_data)
@@ -717,7 +720,7 @@ def test_complementary_data_empty():
def test_complementary_data_no_task():
"""Test processor handles complementary_data without task field."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
complementary_data = {
"episode_id": 123,
@@ -735,7 +738,7 @@ def test_complementary_data_no_task():
def test_complementary_data_mixed():
"""Test processor with mixed complementary_data containing task and other fields."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
complementary_data = {
"task": "stack_blocks",
@@ -760,7 +763,7 @@ def test_complementary_data_mixed():
def test_task_with_observation_and_action():
"""Test task processing together with observation and action processing."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# All components need batching
observation = {
@@ -785,7 +788,7 @@ def test_task_with_observation_and_action():
def test_task_comprehensive_string_cases():
"""Test task processing with comprehensive string cases and edge cases."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Test various string formats
string_tasks = [
@@ -843,7 +846,7 @@ def test_task_comprehensive_string_cases():
def test_task_preserves_other_keys():
"""Test that task processing preserves other keys in complementary_data."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
complementary_data = {
"task": "clean_table",
@@ -871,7 +874,7 @@ def test_task_preserves_other_keys():
# Index and task_index specific tests
def test_index_scalar_to_1d():
"""Test that 0D index tensor gets unsqueezed to 1D."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create 0D index tensor (scalar)
index_0d = torch.tensor(42, dtype=torch.int64)
@@ -888,7 +891,7 @@ def test_index_scalar_to_1d():
def test_task_index_scalar_to_1d():
"""Test that 0D task_index tensor gets unsqueezed to 1D."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create 0D task_index tensor (scalar)
task_index_0d = torch.tensor(7, dtype=torch.int64)
@@ -905,7 +908,7 @@ def test_task_index_scalar_to_1d():
def test_index_and_task_index_together():
"""Test processing both index and task_index together."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create 0D tensors for both
index_0d = torch.tensor(100, dtype=torch.int64)
@@ -935,7 +938,7 @@ def test_index_and_task_index_together():
def test_index_already_batched():
"""Test that already batched index tensors remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create already batched tensors
index_1d = torch.tensor([42], dtype=torch.int64)
@@ -956,7 +959,7 @@ def test_index_already_batched():
def test_task_index_already_batched():
"""Test that already batched task_index tensors remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create already batched tensors
task_index_1d = torch.tensor([7], dtype=torch.int64)
@@ -977,7 +980,7 @@ def test_task_index_already_batched():
def test_index_non_tensor_unchanged():
"""Test that non-tensor index values remain unchanged."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
complementary_data = {
"index": 42, # Plain int, not tensor
@@ -994,7 +997,7 @@ def test_index_non_tensor_unchanged():
def test_index_dtype_preservation():
"""Test that index and task_index dtype is preserved during processing."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Test different dtypes
dtypes = [torch.int32, torch.int64, torch.long]
@@ -1017,7 +1020,7 @@ def test_index_dtype_preservation():
def test_index_with_full_transition():
"""Test index/task_index processing with full transition data."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create full transition with all components
observation = {
@@ -1059,7 +1062,7 @@ def test_index_with_full_transition():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_index_device_compatibility():
"""Test processor works with index/task_index tensors on different devices."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Create tensors on GPU
index_0d = torch.tensor(42, dtype=torch.int64, device="cuda")
@@ -1083,7 +1086,7 @@ def test_index_device_compatibility():
def test_empty_index_tensor():
"""Test handling of empty index tensors."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
# Empty 0D tensor doesn't make sense, but test empty 1D
index_empty = torch.tensor([], dtype=torch.int64)
@@ -1098,7 +1101,7 @@ def test_empty_index_tensor():
def test_action_processing_creates_new_transition():
"""Test that the processor creates a new transition object with correctly processed action."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
action = torch.randn(4)
transition = create_transition(action=action)
@@ -1120,7 +1123,7 @@ def test_action_processing_creates_new_transition():
def test_task_processing_creates_new_transition():
"""Test that the processor creates a new transition object with correctly processed task."""
processor = ToBatchProcessor()
processor = AddBatchDimensionProcessorStep()
complementary_data = {"task": "sort_objects"}
transition = create_transition(complementary_data=complementary_data)
+18 -13
View File
@@ -24,8 +24,13 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import OBS_IMAGE, OBS_STATE
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
from lerobot.processor import DeviceProcessor, IdentityProcessor, NormalizerProcessor, RobotProcessor
from lerobot.processor.pipeline import TransitionKey
from lerobot.processor import (
DataProcessorPipeline,
DeviceProcessorStep,
IdentityProcessorStep,
NormalizerProcessorStep,
TransitionKey,
)
def create_transition(observation=None, action=None, **kwargs):
@@ -82,14 +87,14 @@ def test_make_classifier_processor_basic():
# Check steps in preprocessor
assert len(preprocessor.steps) == 3
assert isinstance(preprocessor.steps[0], NormalizerProcessor) # For input features
assert isinstance(preprocessor.steps[1], NormalizerProcessor) # For output features
assert isinstance(preprocessor.steps[2], DeviceProcessor)
assert isinstance(preprocessor.steps[0], NormalizerProcessorStep) # For input features
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) # For output features
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], DeviceProcessor)
assert isinstance(postprocessor.steps[1], IdentityProcessor)
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
assert isinstance(postprocessor.steps[1], IdentityProcessorStep)
def test_classifier_processor_normalization():
@@ -249,7 +254,7 @@ def test_classifier_processor_save_and_load():
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
preprocessor = DataProcessorPipeline(
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -258,7 +263,7 @@ def test_classifier_processor_save_and_load():
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -286,16 +291,16 @@ def test_classifier_processor_mixed_precision():
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
# Replace DeviceProcessor with one that uses float16
# Replace DeviceProcessorStep with one that uses float16
modified_steps = []
for step in factory_preprocessor.steps:
if isinstance(step, DeviceProcessor):
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
else:
modified_steps.append(step)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
preprocessor = DataProcessorPipeline(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
# Create test data
observation = {
+120 -102
View File
@@ -2,112 +2,16 @@ import numpy as np
import pytest
import torch
from lerobot.processor import TransitionKey
from lerobot.processor.converters import (
to_dataset_frame,
to_output_robot_action,
batch_to_transition,
to_tensor,
to_transition_robot_observation,
to_transition_teleop_action,
transition_to_batch,
transition_to_dataset_frame,
)
from lerobot.processor.pipeline import TransitionKey
def test_to_transition_teleop_action_prefix_and_tensor_conversion():
# Scalars, arrays, and uint8 arrays are all converted to tensors
img = np.zeros((8, 12, 3), dtype=np.uint8)
act = {
"ee.x": 0.5, # scalar to torch tensor
"delta": np.array([1.0, 2.0]), # ndarray to torch tensor
"raw_img": img, # uint8 HWC to torch tensor
}
tr = to_transition_teleop_action(act)
# Should be an EnvTransition-like dict with ACTION populated
assert isinstance(tr, dict)
assert TransitionKey.ACTION in tr
assert "action.ee.x" in tr[TransitionKey.ACTION]
assert "action.delta" in tr[TransitionKey.ACTION]
assert "action.raw_img" in tr[TransitionKey.ACTION]
# Types: all values -> torch tensor
assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor)
assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5)
assert isinstance(tr[TransitionKey.ACTION]["action.delta"], torch.Tensor)
assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,)
assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0]))
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], torch.Tensor)
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == torch.float32 # converted from uint8
assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3)
# Observation is created as empty dict by make_transition
assert TransitionKey.OBSERVATION in tr
assert isinstance(tr[TransitionKey.OBSERVATION], dict)
assert tr[TransitionKey.OBSERVATION] == {}
def test_to_transition_robot_observation_state_vs_images_split():
# Create an observation with mixed content
img = np.full((10, 20, 3), 255, dtype=np.uint8) # image (uint8 HWC)
obs = {
"j1.pos": 10.0, # scalar to state to torch tensor
"j2.pos": np.float32(20.0), # scalar np to state to torch tensor
"image_front": img, # to images passthrough
"flag": np.int32(7), # scalar to state to torch tensor
"arr": np.array([1.5, 2.5]), # vector to state to torch tensor
}
tr = to_transition_robot_observation(obs)
assert isinstance(tr, dict)
assert TransitionKey.OBSERVATION in tr
out = tr[TransitionKey.OBSERVATION]
# Check state keys are present and converted to tensors
for k in ("j1.pos", "j2.pos", "flag", "arr"):
key = f"observation.state.{k}"
assert key in out
v = out[key]
if k != "arr":
assert isinstance(v, torch.Tensor) and v.ndim == 0
else:
assert isinstance(v, torch.Tensor) and v.ndim == 1 and v.shape == (2,)
# Check image present as is
assert "observation.images.image_front" in out
assert isinstance(out["observation.images.image_front"], np.ndarray)
assert out["observation.images.image_front"].dtype == np.uint8
assert out["observation.images.image_front"].shape == (10, 20, 3)
# ACTION should be empty dict by make_transition
assert TransitionKey.ACTION in tr
assert isinstance(tr[TransitionKey.ACTION], dict)
assert tr[TransitionKey.ACTION] == {}
def test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only():
# Build a transition with mixed action keys
tr = {
TransitionKey.ACTION: {
"action.j1.pos": 11.0, # keep "j1.pos"
"action.gripper.pos": torch.tensor(33.0), # keep: tensor accepted
"action.ee.x": 0.5, # ignore (doesn't end with .pos)
"misc": "ignore_me", # ignore (no 'action.' prefix)
}
}
out = to_output_robot_action(tr)
# Only ".pos" keys with "action." prefix are retained and stripped to base names
assert set(out.keys()) == {"j1.pos", "gripper.pos"}
# Values converted to float
assert isinstance(out["j1.pos"], float)
assert isinstance(out["gripper.pos"], float)
assert out["j1.pos"] == pytest.approx(11.0)
assert out["gripper.pos"] == pytest.approx(33.0)
def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
def test_transition_to_dataset_frame_merge_and_pack_vectors_and_metadata():
# Fabricate dataset features (as stored in dataset.meta["features"])
features = {
# Action vector: 3 elements in specific order
@@ -160,7 +64,7 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
}
# Directly call the refactored function
batch = to_dataset_frame([teleop_transition, robot_transition], features)
batch = transition_to_dataset_frame([teleop_transition, robot_transition], features)
# Images passthrough
assert "observation.images.front" in batch
@@ -377,3 +281,117 @@ def test_to_tensor_unsupported_type():
with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
to_tensor(object())
def create_transition(
observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info if info is not None else {},
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
}
def test_batch_to_transition_with_index_fields():
"""Test that batch_to_transition handles index and task_index fields correctly."""
# Create batch with index and task_index fields
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"next.reward": 1.5,
"next.done": False,
"task": ["pick_cube"],
"index": torch.tensor([42], dtype=torch.int64),
"task_index": torch.tensor([3], dtype=torch.int64),
}
transition = batch_to_transition(batch)
# Check basic transition structure
assert TransitionKey.OBSERVATION in transition
assert TransitionKey.ACTION in transition
assert TransitionKey.COMPLEMENTARY_DATA in transition
# Check that index and task_index are in complementary_data
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
assert "index" in comp_data
assert "task_index" in comp_data
assert "task" in comp_data
# Verify values
assert torch.equal(comp_data["index"], batch["index"])
assert torch.equal(comp_data["task_index"], batch["task_index"])
assert comp_data["task"] == batch["task"]
def testtransition_to_batch_with_index_fields():
"""Test that transition_to_batch handles index and task_index fields correctly."""
# Create transition with index and task_index in complementary_data
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
reward=1.5,
done=False,
complementary_data={
"task": ["navigate"],
"index": torch.tensor([100], dtype=torch.int64),
"task_index": torch.tensor([5], dtype=torch.int64),
},
)
batch = transition_to_batch(transition)
# Check that index and task_index are in the batch
assert "index" in batch
assert "task_index" in batch
assert "task" in batch
# Verify values
assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"])
assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"])
assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"]
def test_batch_to_transition_without_index_fields():
"""Test that conversion works without index and task_index fields."""
# Batch without index/task_index
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"task": ["pick_cube"],
}
transition = batch_to_transition(batch)
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
# Should have task but not index/task_index
assert "task" in comp_data
assert "index" not in comp_data
assert "task_index" not in comp_data
def test_transition_to_batch_without_index_fields():
"""Test that conversion works without index and task_index fields."""
# Transition without index/task_index
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
complementary_data={"task": ["navigate"]},
)
batch = transition_to_batch(transition)
# Should have task but not index/task_index
assert "task" in batch
assert "index" not in batch
assert "task_index" not in batch
+79 -74
View File
@@ -18,9 +18,8 @@ import tempfile
import pytest
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor import DeviceProcessor, RobotProcessor
from lerobot.processor.pipeline import TransitionKey
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
def create_transition(
@@ -47,7 +46,7 @@ def create_transition(
def test_basic_functionality():
"""Test basic device processor functionality on CPU."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
# Create a transition with CPU tensors
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
@@ -74,7 +73,7 @@ def test_basic_functionality():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_functionality():
"""Test device processor functionality on CUDA."""
processor = DeviceProcessor(device="cuda")
processor = DeviceProcessorStep(device="cuda")
# Create a transition with CPU tensors
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
@@ -101,7 +100,7 @@ def test_cuda_functionality():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_specific_cuda_device():
"""Test device processor with specific CUDA device."""
processor = DeviceProcessor(device="cuda:0")
processor = DeviceProcessorStep(device="cuda:0")
observation = {"observation.state": torch.randn(10)}
action = torch.randn(5)
@@ -117,7 +116,7 @@ def test_specific_cuda_device():
def test_non_tensor_values():
"""Test that non-tensor values are preserved."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
observation = {
"observation.state": torch.randn(10),
@@ -143,7 +142,7 @@ def test_non_tensor_values():
def test_none_values():
"""Test handling of None values."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
# Test with None observation
transition = create_transition(observation=None, action=torch.randn(5))
@@ -160,7 +159,7 @@ def test_none_values():
def test_empty_observation():
"""Test handling of empty observation dictionary."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
transition = create_transition(observation={}, action=torch.randn(5))
result = processor(transition)
@@ -171,7 +170,7 @@ def test_empty_observation():
def test_scalar_tensors():
"""Test handling of scalar tensors."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
observation = {"observation.scalar": torch.tensor(1.5)}
action = torch.tensor(2.0)
@@ -188,7 +187,7 @@ def test_scalar_tensors():
def test_dtype_preservation():
"""Test that tensor dtypes are preserved."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
observation = {
"observation.float32": torch.randn(5, dtype=torch.float32),
@@ -210,7 +209,7 @@ def test_dtype_preservation():
def test_shape_preservation():
"""Test that tensor shapes are preserved."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
observation = {
"observation.1d": torch.randn(10),
@@ -233,7 +232,7 @@ def test_shape_preservation():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mixed_devices():
"""Test handling of tensors already on different devices."""
processor = DeviceProcessor(device="cuda")
processor = DeviceProcessorStep(device="cuda")
# Create tensors on different devices
observation = {
@@ -254,22 +253,22 @@ def test_mixed_devices():
def test_non_blocking_flag():
"""Test that non_blocking flag is set correctly."""
# CPU processor should have non_blocking=False
cpu_processor = DeviceProcessor(device="cpu")
cpu_processor = DeviceProcessorStep(device="cpu")
assert cpu_processor.non_blocking is False
if torch.cuda.is_available():
# CUDA processor should have non_blocking=True
cuda_processor = DeviceProcessor(device="cuda")
cuda_processor = DeviceProcessorStep(device="cuda")
assert cuda_processor.non_blocking is True
cuda_0_processor = DeviceProcessor(device="cuda:0")
cuda_0_processor = DeviceProcessorStep(device="cuda:0")
assert cuda_0_processor.non_blocking is True
def test_serialization_methods():
"""Test get_config, state_dict, and load_state_dict methods."""
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = DeviceProcessor(device=device)
processor = DeviceProcessorStep(device=device)
# Test get_config
config = processor.get_config()
@@ -290,11 +289,13 @@ def test_serialization_methods():
def test_features():
"""Test that features returns features unchanged."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
PipelineFeatureType.OBSERVATION: {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
},
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
}
result = processor.transform_features(features)
@@ -305,13 +306,13 @@ def test_features():
def test_integration_with_robot_processor():
"""Test integration with RobotProcessor."""
from lerobot.constants import OBS_STATE
from lerobot.processor import ToBatchProcessor
from lerobot.processor import AddBatchDimensionProcessorStep
# Create a pipeline with DeviceProcessor
device_processor = DeviceProcessor(device="cpu")
batch_processor = ToBatchProcessor()
# Create a pipeline with DeviceProcessorStep
device_processor = DeviceProcessorStep(device="cpu")
batch_processor = AddBatchDimensionProcessorStep()
processor = RobotProcessor(
processor = DataProcessorPipeline(
steps=[batch_processor, device_processor],
name="test_pipeline",
to_transition=lambda x: x,
@@ -334,21 +335,21 @@ def test_integration_with_robot_processor():
def test_save_and_load_pretrained():
"""Test saving and loading processor with DeviceProcessor."""
"""Test saving and loading processor with DeviceProcessorStep."""
device = "cuda:0" if torch.cuda.is_available() else "cpu"
processor = DeviceProcessor(device=device, float_dtype="float16")
robot_processor = RobotProcessor(steps=[processor], name="device_test_processor")
processor = DeviceProcessorStep(device=device, float_dtype="float16")
robot_processor = DataProcessorPipeline(steps=[processor], name="device_test_processor")
with tempfile.TemporaryDirectory() as tmpdir:
# Save
robot_processor.save_pretrained(tmpdir)
# Load
loaded_processor = RobotProcessor.from_pretrained(tmpdir)
loaded_processor = DataProcessorPipeline.from_pretrained(tmpdir)
assert len(loaded_processor.steps) == 1
loaded_device_processor = loaded_processor.steps[0]
assert isinstance(loaded_device_processor, DeviceProcessor)
assert isinstance(loaded_device_processor, DeviceProcessorStep)
# Use getattr to access attributes safely
assert (
getattr(loaded_device_processor, "device", None) == device.split(":")[0]
@@ -357,18 +358,18 @@ def test_save_and_load_pretrained():
def test_registry_functionality():
"""Test that DeviceProcessor is properly registered."""
from lerobot.processor.pipeline import ProcessorStepRegistry
"""Test that DeviceProcessorStep is properly registered."""
from lerobot.processor import ProcessorStepRegistry
# Check that DeviceProcessor is registered
# Check that DeviceProcessorStep is registered
registered_class = ProcessorStepRegistry.get("device_processor")
assert registered_class is DeviceProcessor
assert registered_class is DeviceProcessorStep
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_performance_with_large_tensors():
"""Test performance with large tensors and non_blocking flag."""
processor = DeviceProcessor(device="cuda")
processor = DeviceProcessorStep(device="cuda")
# Create large tensors
observation = {
@@ -390,7 +391,7 @@ def test_performance_with_large_tensors():
def test_reward_done_truncated_types():
"""Test handling of different types for reward, done, and truncated."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
# Test with scalar values (not tensors)
transition = create_transition(
@@ -430,7 +431,7 @@ def test_reward_done_truncated_types():
def test_complementary_data_preserved():
"""Test that complementary_data is preserved unchanged."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
complementary_data = {
"task": "pick_object",
@@ -450,13 +451,13 @@ def test_complementary_data_preserved():
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick_object"
assert result[TransitionKey.COMPLEMENTARY_DATA]["episode_id"] == 42
assert result[TransitionKey.COMPLEMENTARY_DATA]["metadata"] == {"sensor": "camera_1"}
# Note: Currently DeviceProcessor doesn't process tensors in complementary_data
# Note: Currently DeviceProcessorStep doesn't process tensors in complementary_data
# This is intentional as complementary_data is typically metadata
def test_float_dtype_conversion():
"""Test float dtype conversion functionality."""
processor = DeviceProcessor(device="cpu", float_dtype="float16")
processor = DeviceProcessorStep(device="cpu", float_dtype="float16")
# Create tensors of different types
observation = {
@@ -486,7 +487,7 @@ def test_float_dtype_conversion():
def test_float_dtype_none():
"""Test that when float_dtype is None, no dtype conversion occurs."""
processor = DeviceProcessor(device="cpu", float_dtype=None)
processor = DeviceProcessorStep(device="cpu", float_dtype=None)
observation = {
"observation.float32": torch.randn(5, dtype=torch.float32),
@@ -507,7 +508,7 @@ def test_float_dtype_none():
def test_float_dtype_bfloat16():
"""Test conversion to bfloat16."""
processor = DeviceProcessor(device="cpu", float_dtype="bfloat16")
processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16")
observation = {"observation.state": torch.randn(5, dtype=torch.float32)}
action = torch.randn(3, dtype=torch.float64)
@@ -521,7 +522,7 @@ def test_float_dtype_bfloat16():
def test_float_dtype_float64():
"""Test conversion to float64."""
processor = DeviceProcessor(device="cpu", float_dtype="float64")
processor = DeviceProcessorStep(device="cpu", float_dtype="float64")
observation = {"observation.state": torch.randn(5, dtype=torch.float16)}
action = torch.randn(3, dtype=torch.float32)
@@ -536,27 +537,27 @@ def test_float_dtype_float64():
def test_float_dtype_invalid():
"""Test that invalid float_dtype raises ValueError."""
with pytest.raises(ValueError, match="Invalid float_dtype 'invalid_dtype'"):
DeviceProcessor(device="cpu", float_dtype="invalid_dtype")
DeviceProcessorStep(device="cpu", float_dtype="invalid_dtype")
def test_float_dtype_aliases():
"""Test that dtype aliases work correctly."""
# Test 'half' alias for float16
processor_half = DeviceProcessor(device="cpu", float_dtype="half")
processor_half = DeviceProcessorStep(device="cpu", float_dtype="half")
assert processor_half._target_float_dtype == torch.float16
# Test 'float' alias for float32
processor_float = DeviceProcessor(device="cpu", float_dtype="float")
processor_float = DeviceProcessorStep(device="cpu", float_dtype="float")
assert processor_float._target_float_dtype == torch.float32
# Test 'double' alias for float64
processor_double = DeviceProcessor(device="cpu", float_dtype="double")
processor_double = DeviceProcessorStep(device="cpu", float_dtype="double")
assert processor_double._target_float_dtype == torch.float64
def test_float_dtype_with_mixed_tensors():
"""Test float dtype conversion with mixed tensor types."""
processor = DeviceProcessor(device="cpu", float_dtype="float32")
processor = DeviceProcessorStep(device="cpu", float_dtype="float32")
observation = {
"observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert
@@ -580,13 +581,13 @@ def test_float_dtype_with_mixed_tensors():
def test_float_dtype_serialization():
"""Test that float_dtype is properly serialized in get_config."""
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = DeviceProcessor(device=device, float_dtype="float16")
processor = DeviceProcessorStep(device=device, float_dtype="float16")
config = processor.get_config()
assert config == {"device": device, "float_dtype": "float16"}
# Test with None float_dtype
processor_none = DeviceProcessor(device="cpu", float_dtype=None)
processor_none = DeviceProcessorStep(device="cpu", float_dtype=None)
config_none = processor_none.get_config()
assert config_none == {"device": "cpu", "float_dtype": None}
@@ -595,7 +596,7 @@ def test_float_dtype_serialization():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_float_dtype_with_cuda():
"""Test float dtype conversion combined with CUDA device."""
processor = DeviceProcessor(device="cuda", float_dtype="float16")
processor = DeviceProcessorStep(device="cuda", float_dtype="float16")
# Create tensors on CPU with different dtypes
observation = {
@@ -620,7 +621,7 @@ def test_float_dtype_with_cuda():
def test_complementary_data_index_fields():
"""Test processing of index and task_index fields in complementary_data."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
# Create transition with index and task_index in complementary_data
complementary_data = {
@@ -658,7 +659,7 @@ def test_complementary_data_index_fields():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_complementary_data_index_fields_cuda():
"""Test moving index and task_index fields to CUDA."""
processor = DeviceProcessor(device="cuda:0")
processor = DeviceProcessorStep(device="cuda:0")
# Create CPU tensors
complementary_data = {
@@ -680,7 +681,7 @@ def test_complementary_data_index_fields_cuda():
def test_complementary_data_without_index_fields():
"""Test that complementary_data without index/task_index fields works correctly."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
complementary_data = {
"task": ["navigate"],
@@ -698,7 +699,7 @@ def test_complementary_data_without_index_fields():
def test_complementary_data_mixed_tensors():
"""Test complementary_data with mix of tensors and non-tensors."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
complementary_data = {
"task": ["pick_and_place"],
@@ -727,7 +728,7 @@ def test_complementary_data_mixed_tensors():
def test_complementary_data_float_dtype_conversion():
"""Test that float dtype conversion doesn't affect int tensors in complementary_data."""
processor = DeviceProcessor(device="cpu", float_dtype="float16")
processor = DeviceProcessorStep(device="cpu", float_dtype="float16")
complementary_data = {
"index": torch.tensor([42], dtype=torch.int64),
@@ -751,7 +752,7 @@ def test_complementary_data_float_dtype_conversion():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_complementary_data_full_pipeline_cuda():
"""Test full transition with complementary_data on CUDA."""
processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16")
# Create full transition with mixed CPU tensors
observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)}
@@ -797,7 +798,7 @@ def test_complementary_data_full_pipeline_cuda():
def test_complementary_data_empty():
"""Test empty complementary_data handling."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
@@ -812,7 +813,7 @@ def test_complementary_data_empty():
def test_complementary_data_none():
"""Test None complementary_data handling."""
processor = DeviceProcessor(device="cpu")
processor = DeviceProcessorStep(device="cpu")
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
@@ -827,8 +828,8 @@ def test_complementary_data_none():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_preserves_gpu_placement():
"""Test that DeviceProcessor preserves GPU placement when tensor is already on GPU."""
processor = DeviceProcessor(device="cuda:0")
"""Test that DeviceProcessorStep preserves GPU placement when tensor is already on GPU."""
processor = DeviceProcessorStep(device="cuda:0")
# Create tensors already on GPU
observation = {
@@ -853,9 +854,9 @@ def test_preserves_gpu_placement():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_multi_gpu_preservation():
"""Test that DeviceProcessor preserves placement on different GPUs in multi-GPU setup."""
"""Test that DeviceProcessorStep preserves placement on different GPUs in multi-GPU setup."""
# Test 1: GPU-to-GPU preservation (cuda:0 config, cuda:1 input)
processor_gpu = DeviceProcessor(device="cuda:0")
processor_gpu = DeviceProcessorStep(device="cuda:0")
# Create tensors on cuda:1 (simulating Accelerate placement)
cuda1_device = torch.device("cuda:1")
@@ -874,7 +875,7 @@ def test_multi_gpu_preservation():
assert result[TransitionKey.ACTION].device == cuda1_device
# Test 2: GPU-to-CPU should move to CPU (not preserve GPU)
processor_cpu = DeviceProcessor(device="cpu")
processor_cpu = DeviceProcessorStep(device="cpu")
transition_gpu = create_transition(
observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda()
@@ -890,7 +891,7 @@ def test_multi_gpu_preservation():
def test_multi_gpu_with_cpu_tensors():
"""Test that CPU tensors are moved to configured device even in multi-GPU context."""
# Processor configured for cuda:1
processor = DeviceProcessor(device="cuda:1")
processor = DeviceProcessorStep(device="cuda:1")
# Mix of CPU and GPU tensors
observation = {
@@ -917,7 +918,7 @@ def test_multi_gpu_with_cpu_tensors():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_multi_gpu_with_float_dtype():
"""Test float dtype conversion works correctly with multi-GPU preservation."""
processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16")
# Create float tensors on different GPUs
observation = {
@@ -947,7 +948,7 @@ def test_simulated_accelerate_scenario():
for gpu_id in range(min(torch.cuda.device_count(), 2)):
# Each "process" has a processor configured for cuda:0
# but data comes in already placed on the process's GPU
processor = DeviceProcessor(device="cuda:0")
processor = DeviceProcessorStep(device="cuda:0")
# Simulate data already placed by Accelerate
device = torch.device(f"cuda:{gpu_id}")
@@ -967,7 +968,11 @@ def test_policy_processor_integration():
"""Test integration with policy processors - input on GPU, output on CPU."""
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.processor import NormalizerProcessor, ToBatchProcessor, UnnormalizerProcessor
from lerobot.processor import (
AddBatchDimensionProcessorStep,
NormalizerProcessorStep,
UnnormalizerProcessorStep,
)
# Create features and stats
features = {
@@ -983,11 +988,11 @@ def test_policy_processor_integration():
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MEAN_STD}
# Create input processor (preprocessor) that moves to GPU
input_processor = RobotProcessor(
input_processor = DataProcessorPipeline(
steps=[
NormalizerProcessor(features=features, norm_map=norm_map, stats=stats),
ToBatchProcessor(),
DeviceProcessor(device="cuda"),
NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device="cuda"),
],
name="test_preprocessor",
to_transition=lambda x: x,
@@ -995,10 +1000,10 @@ def test_policy_processor_integration():
)
# Create output processor (postprocessor) that moves to CPU
output_processor = RobotProcessor(
output_processor = DataProcessorPipeline(
steps=[
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
],
name="test_postprocessor",
to_transition=lambda x: x,
+95 -21
View File
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
DeviceProcessorStep,
NormalizerProcessorStep,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.pipeline import TransitionKey
def create_transition(observation=None, action=None, **kwargs):
@@ -84,20 +84,20 @@ def test_make_diffusion_processor_basic():
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessor)
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
assert isinstance(preprocessor.steps[3], DeviceProcessor)
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], DeviceProcessor)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
def test_diffusion_processor_with_images():
@@ -257,7 +257,7 @@ def test_diffusion_processor_save_and_load():
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
preprocessor = DataProcessorPipeline(
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -266,7 +266,7 @@ def test_diffusion_processor_save_and_load():
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -294,16 +294,27 @@ def test_diffusion_processor_mixed_precision():
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
# Replace DeviceProcessor with one that uses float16
# Replace DeviceProcessorStep with one that uses float16
modified_steps = []
for step in factory_preprocessor.steps:
if isinstance(step, DeviceProcessor):
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else:
modified_steps.append(step)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
preprocessor = DataProcessorPipeline(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
# Create test data
observation = {
@@ -379,3 +390,66 @@ def test_diffusion_processor_batch_consistency():
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == expected_batch
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape[0] == expected_batch
assert processed[TransitionKey.ACTION].shape[0] == expected_batch
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_diffusion_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
# Get the steps from the factory function
factory_preprocessor, _ = make_diffusion_pre_post_processors(config, stats)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in factory_preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
# Create new processor with modified steps
preprocessor = DataProcessorPipeline(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
# Verify initial normalizer configuration
normalizer_step = modified_steps[3] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data with both state and visual observations
observation = {
OBS_STATE: torch.randn(7, dtype=torch.float32),
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(6, dtype=torch.float32)
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
assert (
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
) # IDENTITY normalization still gets dtype conversion
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
# Check state stats (has normalization)
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
+284 -76
View File
@@ -20,13 +20,15 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.converters import to_tensor
from lerobot.processor.normalize_processor import (
NormalizerProcessor,
UnnormalizerProcessor,
from lerobot.processor import (
DataProcessorPipeline,
IdentityProcessorStep,
NormalizerProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
hotswap_stats,
)
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
from lerobot.processor.converters import to_tensor
def create_transition(
@@ -123,7 +125,7 @@ def _create_observation_norm_map():
}
# Fixtures for observation normalisation tests using NormalizerProcessor
# Fixtures for observation normalisation tests using NormalizerProcessorStep
@pytest.fixture
def observation_stats():
return {
@@ -140,10 +142,10 @@ def observation_stats():
@pytest.fixture
def observation_normalizer(observation_stats):
"""Return a NormalizerProcessor that only has observation stats (no action)."""
"""Return a NormalizerProcessorStep that only has observation stats (no action)."""
features = _create_observation_features()
norm_map = _create_observation_norm_map()
return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
return NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats)
def test_mean_std_normalization(observation_normalizer):
@@ -180,7 +182,7 @@ def test_min_max_normalization(observation_normalizer):
def test_selective_normalization(observation_stats):
features = _create_observation_features()
norm_map = _create_observation_norm_map()
normalizer = NormalizerProcessor(
normalizer = NormalizerProcessorStep(
features=features,
norm_map=norm_map,
stats=observation_stats,
@@ -206,7 +208,7 @@ def test_selective_normalization(observation_stats):
def test_device_compatibility(observation_stats):
features = _create_observation_features()
norm_map = _create_observation_norm_map()
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
}
@@ -235,7 +237,7 @@ def test_from_lerobot_dataset():
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
normalizer = NormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map)
# Both observation and action statistics should be present in tensor stats
assert "observation.image" in normalizer._tensor_stats
@@ -250,7 +252,7 @@ def test_state_dict_save_load(observation_normalizer):
# Create new normalizer and load state
features = _create_observation_features()
norm_map = _create_observation_norm_map()
new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
new_normalizer.load_state_dict(state_dict)
# Test that it works the same
@@ -301,7 +303,7 @@ def _create_action_norm_map_min_max():
def test_mean_std_unnormalization(action_stats_mean_std):
features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
unnormalizer = UnnormalizerProcessorStep(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
@@ -319,7 +321,7 @@ def test_mean_std_unnormalization(action_stats_mean_std):
def test_min_max_unnormalization(action_stats_min_max):
features = _create_action_features()
norm_map = _create_action_norm_map_min_max()
unnormalizer = UnnormalizerProcessor(
unnormalizer = UnnormalizerProcessorStep(
features=features, norm_map=norm_map, stats={"action": action_stats_min_max}
)
@@ -345,7 +347,7 @@ def test_min_max_unnormalization(action_stats_min_max):
def test_numpy_action_input(action_stats_mean_std):
features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
unnormalizer = UnnormalizerProcessorStep(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
@@ -363,7 +365,7 @@ def test_numpy_action_input(action_stats_mean_std):
def test_none_action(action_stats_mean_std):
features = _create_action_features()
norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessor(
unnormalizer = UnnormalizerProcessorStep(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
@@ -379,11 +381,11 @@ def test_action_from_lerobot_dataset():
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
features = {"action": PolicyFeature(FeatureType.ACTION, (1,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
unnormalizer = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map)
assert "mean" in unnormalizer._tensor_stats["action"]
# Fixtures for NormalizerProcessor tests
# Fixtures for NormalizerProcessorStep tests
@pytest.fixture
def full_stats():
return {
@@ -422,7 +424,7 @@ def _create_full_norm_map():
def normalizer_processor(full_stats):
features = _create_full_features()
norm_map = _create_full_norm_map()
return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats)
return NormalizerProcessorStep(features=features, norm_map=norm_map, stats=full_stats)
def test_combined_normalization(normalizer_processor):
@@ -466,7 +468,7 @@ def test_processor_from_lerobot_dataset(full_stats):
features = _create_full_features()
norm_map = _create_full_norm_map()
processor = NormalizerProcessor.from_lerobot_dataset(
processor = NormalizerProcessorStep.from_lerobot_dataset(
mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"}
)
@@ -478,7 +480,7 @@ def test_processor_from_lerobot_dataset(full_stats):
def test_get_config(full_stats):
features = _create_full_features()
norm_map = _create_full_norm_map()
processor = NormalizerProcessor(
processor = NormalizerProcessorStep(
features=features,
norm_map=norm_map,
stats=full_stats,
@@ -506,7 +508,9 @@ def test_get_config(full_stats):
def test_integration_with_robot_processor(normalizer_processor):
"""Test integration with RobotProcessor pipeline"""
robot_processor = RobotProcessor([normalizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
robot_processor = DataProcessorPipeline(
[normalizer_processor], to_transition=lambda x: x, to_output=lambda x: x
)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
@@ -535,7 +539,7 @@ def test_empty_observation():
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
transition = create_transition()
result = normalizer(transition)
@@ -546,7 +550,7 @@ def test_empty_observation():
def test_empty_stats():
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
observation = {"observation.image": torch.tensor([0.5])}
transition = create_transition(observation=observation)
@@ -562,7 +566,7 @@ def test_partial_stats():
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7])}
transition = create_transition(observation=observation)
@@ -577,7 +581,7 @@ def test_missing_action_stats_no_error():
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map)
# The tensor stats should not contain the 'action' key
assert "action" not in processor._tensor_stats
@@ -586,7 +590,7 @@ def test_serialization_roundtrip(full_stats):
"""Test that features and norm_map can be serialized and deserialized correctly."""
features = _create_full_features()
norm_map = _create_full_norm_map()
original_processor = NormalizerProcessor(
original_processor = NormalizerProcessorStep(
features=features,
norm_map=norm_map,
stats=full_stats,
@@ -598,7 +602,7 @@ def test_serialization_roundtrip(full_stats):
config = original_processor.get_config()
# Create a new processor from the config (deserialization)
new_processor = NormalizerProcessor(
new_processor = NormalizerProcessorStep(
features=config["features"],
norm_map=config["norm_map"],
stats=full_stats,
@@ -666,7 +670,7 @@ def test_identity_normalization_observations():
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
@@ -691,7 +695,7 @@ def test_identity_normalization_actions():
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
action = torch.tensor([1.0, -0.5])
transition = create_transition(action=action)
@@ -717,7 +721,7 @@ def test_identity_unnormalization_observations():
"observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
}
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
@@ -744,7 +748,7 @@ def test_identity_unnormalization_actions():
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}}
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
action = torch.tensor([0.5, -0.8]) # Normalized values
transition = create_transition(action=action)
@@ -767,8 +771,8 @@ def test_identity_with_missing_stats():
}
stats = {} # No stats provided
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
action = torch.tensor([1.0, -0.5])
@@ -808,7 +812,7 @@ def test_identity_mixed_with_other_modes():
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
@@ -850,7 +854,7 @@ def test_identity_defaults_when_not_in_norm_map():
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
@@ -884,8 +888,8 @@ def test_identity_roundtrip():
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
original_action = torch.tensor([0.5, -0.2])
@@ -917,7 +921,7 @@ def test_identity_config_serialization():
"action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
# Get config
config = normalizer.get_config()
@@ -927,7 +931,7 @@ def test_identity_config_serialization():
assert config["norm_map"]["ACTION"] == "MEAN_STD"
# Create new processor from config (simulating load)
new_normalizer = NormalizerProcessor(
new_normalizer = NormalizerProcessorStep(
features=config["features"],
norm_map=config["norm_map"],
stats=stats,
@@ -965,7 +969,7 @@ def test_identity_config_serialization():
# norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
# normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
# normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
# # Manually inject an invalid mode to test error handling
# normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
@@ -1002,12 +1006,12 @@ def test_hotswap_stats_basic_functionality():
}
# Create processors
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
identity = IdentityProcessor()
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
identity = IdentityProcessorStep()
# Create robot processor
robot_processor = RobotProcessor(steps=[normalizer, unnormalizer, identity])
robot_processor = DataProcessorPipeline(steps=[normalizer, unnormalizer, identity])
# Hotswap stats
new_processor = hotswap_stats(robot_processor, new_stats)
@@ -1043,8 +1047,8 @@ def test_hotswap_stats_deep_copy():
}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
original_processor = RobotProcessor(steps=[normalizer])
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
original_processor = DataProcessorPipeline(steps=[normalizer])
# Store reference to original stats
original_stats_reference = original_processor.steps[0].stats
@@ -1068,7 +1072,7 @@ def test_hotswap_stats_deep_copy():
def test_hotswap_stats_only_affects_normalizer_steps():
"""Test that hotswap_stats only modifies NormalizerProcessor and UnnormalizerProcessor steps."""
"""Test that hotswap_stats only modifies NormalizerProcessorStep and UnnormalizerProcessorStep steps."""
stats = {
"observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
}
@@ -1083,11 +1087,11 @@ def test_hotswap_stats_only_affects_normalizer_steps():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
# Create mixed steps
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
identity = IdentityProcessor()
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
identity = IdentityProcessorStep()
robot_processor = RobotProcessor(steps=[normalizer, identity, unnormalizer])
robot_processor = DataProcessorPipeline(steps=[normalizer, identity, unnormalizer])
# Hotswap stats
new_processor = hotswap_stats(robot_processor, new_stats)
@@ -1113,8 +1117,8 @@ def test_hotswap_stats_empty_stats():
}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
robot_processor = RobotProcessor(steps=[normalizer])
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
robot_processor = DataProcessorPipeline(steps=[normalizer])
# Hotswap with empty stats
new_processor = hotswap_stats(robot_processor, empty_stats)
@@ -1131,7 +1135,7 @@ def test_hotswap_stats_no_normalizer_steps():
}
# Create processor with only identity steps
robot_processor = RobotProcessor(steps=[IdentityProcessor(), IdentityProcessor()])
robot_processor = DataProcessorPipeline(steps=[IdentityProcessorStep(), IdentityProcessorStep()])
# Hotswap stats - should work without error
new_processor = hotswap_stats(robot_processor, stats)
@@ -1163,14 +1167,14 @@ def test_hotswap_stats_preserves_other_attributes():
normalize_observation_keys = {"observation.image"}
eps = 1e-6
normalizer = NormalizerProcessor(
normalizer = NormalizerProcessorStep(
features=features,
norm_map=norm_map,
stats=initial_stats,
normalize_observation_keys=normalize_observation_keys,
eps=eps,
)
robot_processor = RobotProcessor(steps=[normalizer])
robot_processor = DataProcessorPipeline(steps=[normalizer])
# Hotswap stats
new_processor = hotswap_stats(robot_processor, new_stats)
@@ -1208,12 +1212,12 @@ def test_hotswap_stats_multiple_normalizer_types():
}
# Create multiple normalizers and unnormalizers
normalizer1 = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
normalizer2 = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
unnormalizer1 = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
unnormalizer2 = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
unnormalizer1 = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
unnormalizer2 = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
robot_processor = RobotProcessor(steps=[normalizer1, unnormalizer1, normalizer2, unnormalizer2])
robot_processor = DataProcessorPipeline(steps=[normalizer1, unnormalizer1, normalizer2, unnormalizer2])
# Hotswap stats
new_processor = hotswap_stats(robot_processor, new_stats)
@@ -1260,8 +1264,8 @@ def test_hotswap_stats_with_different_data_types():
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
robot_processor = RobotProcessor(steps=[normalizer])
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
robot_processor = DataProcessorPipeline(steps=[normalizer])
# Hotswap stats
new_processor = hotswap_stats(robot_processor, new_stats)
@@ -1316,8 +1320,10 @@ def test_hotswap_stats_functional_test():
}
# Create original processor
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
original_processor = RobotProcessor(steps=[normalizer], to_transition=lambda x: x, to_output=lambda x: x)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
original_processor = DataProcessorPipeline(
steps=[normalizer], to_transition=lambda x: x, to_output=lambda x: x
)
# Process with original stats
original_result = original_processor(transition)
@@ -1360,7 +1366,7 @@ def test_zero_std_uses_eps():
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": np.array([0.5]), "std": np.array([0.0])}}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
observation = {"observation.state": torch.tensor([0.5])} # equals mean
out = normalizer(create_transition(observation=observation))
@@ -1372,7 +1378,7 @@ def test_min_equals_max_maps_to_minus_one():
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX}
stats = {"observation.state": {"min": np.array([2.0]), "max": np.array([2.0])}}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
observation = {"observation.state": torch.tensor([2.0])}
out = normalizer(create_transition(observation=observation))
@@ -1387,7 +1393,7 @@ def test_action_normalized_despite_normalize_observation_keys():
}
norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD}
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
normalizer = NormalizerProcessor(
normalizer = NormalizerProcessorStep(
features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"}
)
@@ -1405,12 +1411,12 @@ def test_unnormalize_observations_mean_std_and_min_max():
"observation.mm": PolicyFeature(FeatureType.STATE, (2,)),
}
# Build two processors: one mean/std and one min/max
unnorm_ms = UnnormalizerProcessor(
unnorm_ms = UnnormalizerProcessorStep(
features={"observation.ms": features["observation.ms"]},
norm_map={FeatureType.STATE: NormalizationMode.MEAN_STD},
stats={"observation.ms": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}},
)
unnorm_mm = UnnormalizerProcessor(
unnorm_mm = UnnormalizerProcessorStep(
features={"observation.mm": features["observation.mm"]},
norm_map={FeatureType.STATE: NormalizationMode.MIN_MAX},
stats={"observation.mm": {"min": np.array([0.0, -2.0]), "max": np.array([2.0, 2.0])}},
@@ -1432,7 +1438,7 @@ def test_unknown_observation_keys_ignored():
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
obs = {"observation.state": torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])}
tr = create_transition(observation=obs)
@@ -1446,7 +1452,7 @@ def test_batched_action_normalization():
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
actions = torch.tensor([[1.0, -1.0], [3.0, 3.0]]) # first equals mean → zeros; second → [1, 1]
out = normalizer(create_transition(action=actions))[TransitionKey.ACTION]
@@ -1458,7 +1464,7 @@ def test_complementary_data_preservation():
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
comp = {"existing": 123}
tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp)
@@ -1477,8 +1483,8 @@ def test_roundtrip_normalize_unnormalize_non_identity():
"observation.state": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])},
"action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
# Add a time dimension in action for broadcasting check (B,T,D)
obs = {"observation.state": torch.tensor([[3.0, 3.0], [1.0, -1.0]])}
@@ -1491,3 +1497,205 @@ def test_roundtrip_normalize_unnormalize_non_identity():
out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5
)
assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5)
def test_dtype_adaptation_bfloat16_input_float32_normalizer():
"""Test automatic dtype adaptation: NormalizerProcessor(float32) adapts to bfloat16 input → bfloat16 output"""
features = {"observation.state": PolicyFeature(FeatureType.STATE, (5,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {
"observation.state": {
"mean": np.array([0.0, 0.0, 0.0, 0.0, 0.0]),
"std": np.array([1.0, 1.0, 1.0, 1.0, 1.0]),
}
}
# Create normalizer configured with float32 dtype
normalizer = NormalizerProcessorStep(
features=features, norm_map=norm_map, stats=stats, dtype=torch.float32
)
# Verify initial configuration
assert normalizer.dtype == torch.float32
for stat_tensor in normalizer._tensor_stats["observation.state"].values():
assert stat_tensor.dtype == torch.float32
# Create bfloat16 input tensor
observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)}
transition = create_transition(observation=observation)
# Process the transition
result = normalizer(transition)
# Verify that:
# 1. Stats were automatically adapted to bfloat16
assert normalizer.dtype == torch.bfloat16
for stat_tensor in normalizer._tensor_stats["observation.state"].values():
assert stat_tensor.dtype == torch.bfloat16
# 2. Output is in bfloat16
output_tensor = result[TransitionKey.OBSERVATION]["observation.state"]
assert output_tensor.dtype == torch.bfloat16
# 3. Normalization was applied correctly (mean should be close to original - mean) / std
expected = (
torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)
- torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], dtype=torch.bfloat16)
) / torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.bfloat16)
assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
"""Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output"""
from lerobot.processor import DeviceProcessorStep
features = {"observation.state": PolicyFeature(FeatureType.STATE, (3,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}}
# Create pipeline: DeviceProcessor(bfloat16) → NormalizerProcessor(float32)
device_processor = DeviceProcessorStep(device="cuda", float_dtype="bfloat16")
normalizer = NormalizerProcessorStep(
features=features, norm_map=norm_map, stats=stats, dtype=torch.float32
)
# Verify initial normalizer configuration
assert normalizer.dtype == torch.float32
# Create CPU input
observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)}
transition = create_transition(observation=observation)
# Step 1: DeviceProcessor converts to bfloat16 + moves to CUDA
processed_1 = device_processor(transition)
intermediate_tensor = processed_1[TransitionKey.OBSERVATION]["observation.state"]
assert intermediate_tensor.dtype == torch.bfloat16
assert intermediate_tensor.device.type == "cuda"
# Step 2: NormalizerProcessor receives bfloat16 input and adapts
final_result = normalizer(processed_1)
final_tensor = final_result[TransitionKey.OBSERVATION]["observation.state"]
# Verify final output is bfloat16 (automatic adaptation worked)
assert final_tensor.dtype == torch.bfloat16
assert final_tensor.device.type == "cuda"
# Verify normalizer adapted its internal state
assert normalizer.dtype == torch.bfloat16
for stat_tensor in normalizer._tensor_stats["observation.state"].values():
assert stat_tensor.dtype == torch.bfloat16
assert stat_tensor.device.type == "cuda"
def test_stats_reconstruction_after_load_state_dict():
"""
Test that stats dict is properly reconstructed from _tensor_stats after loading.
This test ensures the bug where stats became empty after loading is fixed.
The bug occurred when:
1. Only _tensor_stats were saved via state_dict()
2. stats field became empty {} after loading
3. Calling to() method or hotswap_stats would fail because they depend on self.stats
"""
# Create normalizer with stats
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"observation.state": {
"min": np.array([0.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
"action": {
"mean": np.array([0.0, 0.0]),
"std": np.array([1.0, 2.0]),
},
}
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
# Save state dict (simulating save/load)
state_dict = original_normalizer.state_dict()
# Create new normalizer with empty stats (simulating load)
new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
# Before fix: this would cause stats to remain empty
new_normalizer.load_state_dict(state_dict)
# Verify that stats dict is properly reconstructed from _tensor_stats
assert new_normalizer.stats is not None
assert new_normalizer.stats != {}
# Check that all expected keys are present
assert "observation.image" in new_normalizer.stats
assert "observation.state" in new_normalizer.stats
assert "action" in new_normalizer.stats
# Check that values are correct (converted back from tensors)
np.testing.assert_allclose(new_normalizer.stats["observation.image"]["mean"], [0.5, 0.5, 0.5])
np.testing.assert_allclose(new_normalizer.stats["observation.image"]["std"], [0.2, 0.2, 0.2])
np.testing.assert_allclose(new_normalizer.stats["observation.state"]["min"], [0.0, -1.0])
np.testing.assert_allclose(new_normalizer.stats["observation.state"]["max"], [1.0, 1.0])
np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0])
np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0])
# Test that methods that depend on self.stats work correctly after loading
# This would fail before the bug fix because self.stats was empty
# Test 1: to() method should work without crashing
try:
new_normalizer.to(device="cpu", dtype=torch.float32)
# If we reach here, the bug is fixed
except (KeyError, AttributeError) as e:
pytest.fail(f"to() method failed after loading state_dict: {e}")
# Test 2: hotswap_stats should work
new_stats = {
"observation.image": {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]},
"observation.state": {"min": [-1.0, -2.0], "max": [2.0, 2.0]},
"action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]},
}
pipeline = DataProcessorPipeline([new_normalizer])
try:
new_pipeline = hotswap_stats(pipeline, new_stats)
# If we reach here, hotswap_stats worked correctly
assert new_pipeline.steps[0].stats == new_stats
except (KeyError, AttributeError) as e:
pytest.fail(f"hotswap_stats failed after loading state_dict: {e}")
# Test 3: The normalizer should work functionally the same as the original
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(observation=observation, action=action)
original_result = original_normalizer(transition)
new_result = new_normalizer(transition)
# Results should be identical (within floating point precision)
torch.testing.assert_close(
original_result[TransitionKey.OBSERVATION]["observation.image"],
new_result[TransitionKey.OBSERVATION]["observation.image"],
)
torch.testing.assert_close(
original_result[TransitionKey.OBSERVATION]["observation.state"],
new_result[TransitionKey.OBSERVATION]["observation.state"],
)
torch.testing.assert_close(original_result[TransitionKey.ACTION], new_result[TransitionKey.ACTION])
+113 -58
View File
@@ -18,10 +18,9 @@ import numpy as np
import pytest
import torch
from lerobot.configs.types import FeatureType
from lerobot.configs.types import FeatureType, PipelineFeatureType
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import VanillaObservationProcessor
from lerobot.processor.pipeline import TransitionKey
from lerobot.processor import TransitionKey, VanillaObservationProcessorStep
from tests.conftest import assert_contract_is_typed
@@ -42,7 +41,7 @@ def create_transition(
def test_process_single_image():
"""Test processing a single image."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
# Create a mock image (H, W, C) format, uint8
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
@@ -68,7 +67,7 @@ def test_process_single_image():
def test_process_image_dict():
"""Test processing multiple images in a dictionary."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
# Create mock images
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
@@ -91,7 +90,7 @@ def test_process_image_dict():
def test_process_batched_image():
"""Test processing already batched images."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
# Create a batched image (B, H, W, C)
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
@@ -108,7 +107,7 @@ def test_process_batched_image():
def test_invalid_image_format():
"""Test error handling for invalid image formats."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
@@ -121,7 +120,7 @@ def test_invalid_image_format():
def test_invalid_image_dtype():
"""Test error handling for invalid image dtype."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
@@ -134,7 +133,7 @@ def test_invalid_image_dtype():
def test_no_pixels_in_observation():
"""Test processor when no pixels are in observation."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
@@ -149,7 +148,7 @@ def test_no_pixels_in_observation():
def test_none_observation():
"""Test processor with None observation."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
transition = create_transition()
result = processor(transition)
@@ -159,7 +158,7 @@ def test_none_observation():
def test_serialization_methods():
"""Test serialization methods."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
# Test get_config
config = processor.get_config()
@@ -178,7 +177,7 @@ def test_serialization_methods():
def test_process_environment_state():
"""Test processing environment_state."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
@@ -199,7 +198,7 @@ def test_process_environment_state():
def test_process_agent_pos():
"""Test processing agent_pos."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
@@ -220,7 +219,7 @@ def test_process_agent_pos():
def test_process_batched_states():
"""Test processing already batched states."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
@@ -238,7 +237,7 @@ def test_process_batched_states():
def test_process_both_states():
"""Test processing both environment_state and agent_pos."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
env_state = np.array([1.0, 2.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
@@ -263,7 +262,7 @@ def test_process_both_states():
def test_no_states_in_observation():
"""Test processor when no states are in observation."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
@@ -277,7 +276,7 @@ def test_no_states_in_observation():
def test_complete_observation_processing():
"""Test processing a complete observation with both images and states."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
# Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
@@ -314,7 +313,7 @@ def test_complete_observation_processing():
def test_image_only_processing():
"""Test processing observation with only images."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
@@ -329,7 +328,7 @@ def test_image_only_processing():
def test_state_only_processing():
"""Test processing observation with only states."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
@@ -344,7 +343,7 @@ def test_state_only_processing():
def test_empty_observation():
"""Test processing empty observation."""
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
observation = {}
transition = create_transition(observation=observation)
@@ -360,7 +359,7 @@ def test_equivalent_to_original_function():
# Import the original function for comparison
from lerobot.envs.utils import preprocess_observation
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
# Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
@@ -387,7 +386,7 @@ def test_equivalent_with_image_dict():
"""Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
# Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
@@ -411,76 +410,132 @@ def test_equivalent_with_image_dict():
def test_image_processor_features_pixels_to_image(policy_feature_factory):
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
features = {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
PipelineFeatureType.OBSERVATION: {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
},
}
out = processor.transform_features(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
assert "pixels" not in out
assert out["keep"] == features["keep"]
assert (
OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
== features[PipelineFeatureType.OBSERVATION]["pixels"]
)
assert "pixels" not in out[PipelineFeatureType.OBSERVATION]
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
assert_contract_is_typed(out)
def test_image_processor_features_observation_pixels_to_image(policy_feature_factory):
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
features = {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
PipelineFeatureType.OBSERVATION: {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
},
}
out = processor.transform_features(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
assert "observation.pixels" not in out
assert out["keep"] == features["keep"]
assert (
OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
== features[PipelineFeatureType.OBSERVATION]["observation.pixels"]
)
assert "observation.pixels" not in out[PipelineFeatureType.OBSERVATION]
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
assert_contract_is_typed(out)
def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory):
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
features = {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
PipelineFeatureType.OBSERVATION: {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
},
}
out = processor.transform_features(features.copy())
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
assert out["keep"] == features["keep"]
assert (
f"{OBS_IMAGES}.front" in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"]
== features[PipelineFeatureType.OBSERVATION]["pixels.front"]
)
assert (
f"{OBS_IMAGES}.wrist" in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.wrist"]
== features[PipelineFeatureType.OBSERVATION]["pixels.wrist"]
)
assert (
f"{OBS_IMAGES}.rear" in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.rear"]
== features[PipelineFeatureType.OBSERVATION]["observation.pixels.rear"]
)
assert (
"pixels.front" not in out[PipelineFeatureType.OBSERVATION]
and "pixels.wrist" not in out[PipelineFeatureType.OBSERVATION]
and "observation.pixels.rear" not in out[PipelineFeatureType.OBSERVATION]
)
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
assert_contract_is_typed(out)
def test_state_processor_features_environment_and_agent_pos(policy_feature_factory):
processor = VanillaObservationProcessor()
processor = VanillaObservationProcessorStep()
features = {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
PipelineFeatureType.OBSERVATION: {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
},
}
out = processor.transform_features(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert out["keep"] == features["keep"]
assert (
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
== features[PipelineFeatureType.OBSERVATION]["environment_state"]
)
assert (
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
== features[PipelineFeatureType.OBSERVATION]["agent_pos"]
)
assert (
"environment_state" not in out[PipelineFeatureType.OBSERVATION]
and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
)
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
assert_contract_is_typed(out)
def test_state_processor_features_prefixed_inputs(policy_feature_factory):
proc = VanillaObservationProcessor()
proc = VanillaObservationProcessorStep()
features = {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
PipelineFeatureType.OBSERVATION: {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
},
}
out = proc.transform_features(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert (
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
== features[PipelineFeatureType.OBSERVATION]["observation.environment_state"]
)
assert (
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
== features[PipelineFeatureType.OBSERVATION]["observation.agent_pos"]
)
assert (
"environment_state" not in out[PipelineFeatureType.OBSERVATION]
and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
)
assert_contract_is_typed(out)
+110 -24
View File
@@ -25,13 +25,31 @@ from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
ProcessorStep,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.pipeline import TransitionKey
class MockTokenizerProcessorStep(ProcessorStep):
"""Mock tokenizer processor step for testing."""
def __init__(self, *args, **kwargs):
# Accept any arguments to mimic the real TokenizerProcessorStep interface
pass
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Pass through transition unchanged
return transition
def transform_features(self, features):
# Pass through features unchanged
return features
def create_transition(observation=None, action=None, **kwargs):
@@ -83,7 +101,7 @@ def test_make_pi0_processor_basic():
config = create_default_config()
stats = create_default_stats()
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
@@ -92,22 +110,22 @@ def test_make_pi0_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 6
assert isinstance(preprocessor.steps[0], RenameProcessor)
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
assert isinstance(preprocessor.steps[3], Pi0NewLineProcessor)
# Step 4 would be TokenizerProcessor but it's mocked
assert isinstance(preprocessor.steps[5], DeviceProcessor)
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor)
# Step 3 would be TokenizerProcessorStep but it's mocked
assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], DeviceProcessor)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
def test_pi0_newline_processor_single_task():
@@ -165,7 +183,7 @@ def test_pi0_processor_cuda():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessor:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -187,7 +205,7 @@ def test_pi0_processor_cuda():
def transform_features(self, features):
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
@@ -220,7 +238,7 @@ def test_pi0_processor_accelerate_scenario():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessor:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -242,7 +260,7 @@ def test_pi0_processor_accelerate_scenario():
def transform_features(self, features):
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
@@ -276,7 +294,7 @@ def test_pi0_processor_multi_gpu():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessor:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -298,7 +316,7 @@ def test_pi0_processor_multi_gpu():
def transform_features(self, features):
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
@@ -329,7 +347,7 @@ def test_pi0_processor_without_stats():
config = create_default_config()
# Mock the tokenizer processor
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
dataset_stats=None,
@@ -359,3 +377,71 @@ def test_pi0_newline_processor_state_dict():
# Test get_config
config = processor.get_config()
assert config == {}
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
stats = create_default_stats()
config.device = "cuda"
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, _ = make_pi0_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5)
normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data with both state and visual observations
observation = {
OBS_STATE: torch.randn(10, dtype=torch.float32), # PI0 expects size 10
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
transition = create_transition(
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
)
# Process through full pipeline
processed = preprocessor(transition)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
assert (
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
) # IDENTITY normalization still gets dtype conversion
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
# Check state stats (has normalization)
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
File diff suppressed because it is too large Load Diff
+85 -60
View File
@@ -19,8 +19,13 @@ from pathlib import Path
import numpy as np
import torch
from lerobot.configs.types import FeatureType
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
from lerobot.configs.types import FeatureType, PipelineFeatureType
from lerobot.processor import (
DataProcessorPipeline,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TransitionKey,
)
from lerobot.processor.rename_processor import rename_stats
from tests.conftest import assert_contract_is_typed
@@ -46,7 +51,7 @@ def test_basic_renaming():
"old_key1": "new_key1",
"old_key2": "new_key2",
}
processor = RenameProcessor(rename_map=rename_map)
processor = RenameObservationsProcessorStep(rename_map=rename_map)
observation = {
"old_key1": torch.tensor([1.0, 2.0]),
@@ -74,7 +79,7 @@ def test_basic_renaming():
def test_empty_rename_map():
"""Test processor with empty rename map (should pass through unchanged)."""
processor = RenameProcessor(rename_map={})
processor = RenameObservationsProcessorStep(rename_map={})
observation = {
"key1": torch.tensor([1.0]),
@@ -93,7 +98,7 @@ def test_empty_rename_map():
def test_none_observation():
"""Test processor with None observation."""
processor = RenameProcessor(rename_map={"old": "new"})
processor = RenameObservationsProcessorStep(rename_map={"old": "new"})
transition = create_transition()
result = processor(transition)
@@ -108,7 +113,7 @@ def test_overlapping_rename():
"a": "b",
"b": "c", # This creates a potential conflict
}
processor = RenameProcessor(rename_map=rename_map)
processor = RenameObservationsProcessorStep(rename_map=rename_map)
observation = {
"a": 1,
@@ -133,7 +138,7 @@ def test_partial_rename():
"observation.state": "observation.proprio_state",
"pixels": "observation.image",
}
processor = RenameProcessor(rename_map=rename_map)
processor = RenameObservationsProcessorStep(rename_map=rename_map)
observation = {
"observation.state": torch.randn(10),
@@ -163,15 +168,15 @@ def test_get_config():
"old1": "new1",
"old2": "new2",
}
processor = RenameProcessor(rename_map=rename_map)
processor = RenameObservationsProcessorStep(rename_map=rename_map)
config = processor.get_config()
assert config == {"rename_map": rename_map}
def test_state_dict():
"""Test state dict (should be empty for RenameProcessor)."""
processor = RenameProcessor(rename_map={"old": "new"})
"""Test state dict (should be empty for RenameProcessorStep)."""
processor = RenameObservationsProcessorStep(rename_map={"old": "new"})
state = processor.state_dict()
assert state == {}
@@ -186,9 +191,9 @@ def test_integration_with_robot_processor():
"agent_pos": "observation.state",
"pixels": "observation.image",
}
rename_processor = RenameProcessor(rename_map=rename_map)
rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)
pipeline = RobotProcessor([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = DataProcessorPipeline([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)
observation = {
"agent_pos": np.array([1.0, 2.0, 3.0]),
@@ -220,32 +225,34 @@ def test_save_and_load_pretrained():
"old_state": "observation.state",
"old_image": "observation.image",
}
processor = RenameProcessor(rename_map=rename_map)
pipeline = RobotProcessor([processor], name="TestRenameProcessor")
processor = RenameObservationsProcessorStep(rename_map=rename_map)
pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep")
with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline
pipeline.save_pretrained(tmp_dir)
# Check files were created
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
config_path = (
Path(tmp_dir) / "testrenameprocessorstep.json"
) # Based on name="TestRenameProcessorStep"
assert config_path.exists()
# No state files should be created for RenameProcessor
# No state files should be created for RenameProcessorStep
state_files = list(Path(tmp_dir).glob("*.safetensors"))
assert len(state_files) == 0
# Load pipeline
loaded_pipeline = RobotProcessor.from_pretrained(
loaded_pipeline = DataProcessorPipeline.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
assert loaded_pipeline.name == "TestRenameProcessor"
assert loaded_pipeline.name == "TestRenameProcessorStep"
assert len(loaded_pipeline) == 1
# Check that loaded processor works correctly
loaded_processor = loaded_pipeline.steps[0]
assert isinstance(loaded_processor, RenameProcessor)
assert isinstance(loaded_processor, RenameObservationsProcessorStep)
assert loaded_processor.rename_map == rename_map
# Test functionality after loading
@@ -262,24 +269,24 @@ def test_save_and_load_pretrained():
def test_registry_functionality():
"""Test that RenameProcessor is properly registered."""
"""Test that RenameProcessorStep is properly registered."""
# Check that it's registered
assert "rename_processor" in ProcessorStepRegistry.list()
assert "rename_observations_processor" in ProcessorStepRegistry.list()
# Get from registry
retrieved_class = ProcessorStepRegistry.get("rename_processor")
assert retrieved_class is RenameProcessor
retrieved_class = ProcessorStepRegistry.get("rename_observations_processor")
assert retrieved_class is RenameObservationsProcessorStep
# Create instance from registry
instance = retrieved_class(rename_map={"old": "new"})
assert isinstance(instance, RenameProcessor)
assert isinstance(instance, RenameObservationsProcessorStep)
assert instance.rename_map == {"old": "new"}
def test_registry_based_save_load():
"""Test save/load using registry name instead of module path."""
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
processor = RenameObservationsProcessorStep(rename_map={"key1": "renamed_key1"})
pipeline = DataProcessorPipeline([processor], to_transition=lambda x: x, to_output=lambda x: x)
with tempfile.TemporaryDirectory() as tmp_dir:
# Save and load
@@ -288,24 +295,24 @@ def test_registry_based_save_load():
# Verify config uses registry name
import json
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
with open(Path(tmp_dir) / "dataprocessorpipeline.json") as f: # Default name is "RobotProcessor"
config = json.load(f)
assert "registry_name" in config["steps"][0]
assert config["steps"][0]["registry_name"] == "rename_processor"
assert config["steps"][0]["registry_name"] == "rename_observations_processor"
assert "class" not in config["steps"][0] # Should use registry, not module path
# Load should work
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
loaded_pipeline = DataProcessorPipeline.from_pretrained(tmp_dir)
loaded_processor = loaded_pipeline.steps[0]
assert isinstance(loaded_processor, RenameProcessor)
assert isinstance(loaded_processor, RenameObservationsProcessorStep)
assert loaded_processor.rename_map == {"key1": "renamed_key1"}
def test_chained_rename_processors():
"""Test multiple RenameProcessors in a pipeline."""
"""Test multiple RenameProcessorSteps in a pipeline."""
# First processor: rename raw keys to intermediate format
processor1 = RenameProcessor(
processor1 = RenameObservationsProcessorStep(
rename_map={
"pos": "agent_position",
"img": "camera_image",
@@ -313,14 +320,16 @@ def test_chained_rename_processors():
)
# Second processor: rename to final format
processor2 = RenameProcessor(
processor2 = RenameObservationsProcessorStep(
rename_map={
"agent_position": "observation.state",
"camera_image": "observation.image",
}
)
pipeline = RobotProcessor([processor1, processor2], to_transition=lambda x: x, to_output=lambda x: x)
pipeline = DataProcessorPipeline(
[processor1, processor2], to_transition=lambda x: x, to_output=lambda x: x
)
observation = {
"pos": np.array([1.0, 2.0]),
@@ -356,7 +365,7 @@ def test_nested_observation_rename():
"observation.images.right": "observation.camera.right_view",
"observation.proprio": "observation.proprioception",
}
processor = RenameProcessor(rename_map=rename_map)
processor = RenameObservationsProcessorStep(rename_map=rename_map)
observation = {
"observation.images.left": torch.randn(3, 64, 64),
@@ -386,7 +395,7 @@ def test_nested_observation_rename():
def test_value_types_preserved():
"""Test that various value types are preserved during renaming."""
rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
processor = RenameProcessor(rename_map=rename_map)
processor = RenameObservationsProcessorStep(rename_map=rename_map)
tensor_value = torch.randn(3, 3)
array_value = np.random.rand(2, 2)
@@ -414,59 +423,75 @@ def test_value_types_preserved():
def test_features_basic_renaming(policy_feature_factory):
processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
processor = RenameObservationsProcessorStep(rename_map={"a": "x", "b": "y"})
features = {
"a": policy_feature_factory(FeatureType.STATE, (2,)),
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
"c": policy_feature_factory(FeatureType.ENV, (1,)),
PipelineFeatureType.OBSERVATION: {
"a": policy_feature_factory(FeatureType.VISUAL, (2,)),
"b": policy_feature_factory(FeatureType.VISUAL, (3,)),
"c": policy_feature_factory(FeatureType.VISUAL, (1,)),
},
}
out = processor.transform_features(features.copy())
# Values preserved and typed
assert out["x"] == features["a"]
assert out["y"] == features["b"]
assert out["c"] == features["c"]
assert out[PipelineFeatureType.OBSERVATION]["x"] == features[PipelineFeatureType.OBSERVATION]["a"]
assert out[PipelineFeatureType.OBSERVATION]["y"] == features[PipelineFeatureType.OBSERVATION]["b"]
assert out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["c"]
assert_contract_is_typed(out)
# Input not mutated
assert set(features) == {"a", "b", "c"}
assert set(features[PipelineFeatureType.OBSERVATION]) == {"a", "b", "c"}
def test_features_overlapping_keys(policy_feature_factory):
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
processor = RenameObservationsProcessorStep(rename_map={"a": "b", "b": "c"})
features = {
"a": policy_feature_factory(FeatureType.STATE, (1,)),
"b": policy_feature_factory(FeatureType.STATE, (2,)),
PipelineFeatureType.OBSERVATION: {
"a": policy_feature_factory(FeatureType.VISUAL, (1,)),
"b": policy_feature_factory(FeatureType.VISUAL, (2,)),
},
}
out = processor.transform_features(features)
assert set(out) == {"b", "c"}
assert out["b"] == features["a"] # 'a' renamed to'b'
assert out["c"] == features["b"] # 'b' renamed to 'c'
assert set(out[PipelineFeatureType.OBSERVATION]) == {"b", "c"}
assert (
out[PipelineFeatureType.OBSERVATION]["b"] == features[PipelineFeatureType.OBSERVATION]["a"]
) # 'a' renamed to'b'
assert (
out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["b"]
) # 'b' renamed to 'c'
assert_contract_is_typed(out)
def test_features_chained_processors(policy_feature_factory):
# Chain two rename processors at the contract level
processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
processor2 = RenameProcessor(
processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
processor2 = RenameObservationsProcessorStep(
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
)
pipeline = RobotProcessor([processor1, processor2])
pipeline = DataProcessorPipeline([processor1, processor2])
spec = {
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
PipelineFeatureType.OBSERVATION: {
"pos": policy_feature_factory(FeatureType.VISUAL, (7,)),
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"extra": policy_feature_factory(FeatureType.VISUAL, (1,)),
},
}
out = pipeline.transform_features(initial_features=spec)
assert set(out) == {"observation.state", "observation.image", "extra"}
assert out["observation.state"] == spec["pos"]
assert out["observation.image"] == spec["img"]
assert out["extra"] == spec["extra"]
assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"}
assert (
out[PipelineFeatureType.OBSERVATION]["observation.state"]
== spec[PipelineFeatureType.OBSERVATION]["pos"]
)
assert (
out[PipelineFeatureType.OBSERVATION]["observation.image"]
== spec[PipelineFeatureType.OBSERVATION]["img"]
)
assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"]
assert_contract_is_typed(out)
+94 -22
View File
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
DeviceProcessorStep,
NormalizerProcessorStep,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.pipeline import TransitionKey
def create_transition(observation=None, action=None, **kwargs):
@@ -86,20 +86,20 @@ def test_make_sac_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessor)
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
assert isinstance(preprocessor.steps[3], DeviceProcessor)
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], DeviceProcessor)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
def test_sac_processor_normalization_modes():
@@ -234,13 +234,13 @@ def test_sac_processor_without_stats():
factory_preprocessor, factory_postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
preprocessor = DataProcessorPipeline(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
postprocessor = DataProcessorPipeline(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
@@ -277,7 +277,7 @@ def test_sac_processor_save_and_load():
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -306,10 +306,25 @@ def test_sac_processor_mixed_precision():
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
if isinstance(step, DeviceProcessor):
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
# Replace DeviceProcessorStep with one that uses float16
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Create test data
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
@@ -374,3 +389,60 @@ def test_sac_processor_edge_cases():
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
# When action is None, it may still be present with None value
assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, _ = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Verify initial normalizer configuration
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32
action = torch.randn(5, dtype=torch.float32)
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16
+124 -26
View File
@@ -20,7 +20,7 @@ from unittest.mock import patch
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.processor_smolvla import (
@@ -28,13 +28,31 @@ from lerobot.policies.smolvla.processor_smolvla import (
make_smolvla_pre_post_processors,
)
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
ProcessorStep,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.pipeline import TransitionKey
class MockTokenizerProcessorStep(ProcessorStep):
"""Mock tokenizer processor step for testing."""
def __init__(self, *args, **kwargs):
# Accept any arguments to mimic the real TokenizerProcessorStep interface
pass
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Pass through transition unchanged
return transition
def transform_features(self, features):
# Pass through features unchanged
return features
def create_transition(observation=None, action=None, **kwargs):
@@ -88,7 +106,9 @@ def test_make_smolvla_processor_basic():
config = create_default_config()
stats = create_default_stats()
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
@@ -97,22 +117,22 @@ def test_make_smolvla_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 6
assert isinstance(preprocessor.steps[0], RenameProcessor)
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
assert isinstance(preprocessor.steps[3], SmolVLANewLineProcessor)
# Step 4 would be TokenizerProcessor but it's mocked
assert isinstance(preprocessor.steps[5], DeviceProcessor)
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], SmolVLANewLineProcessor)
# Step 3 would be TokenizerProcessorStep but it's mocked
assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], DeviceProcessor)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
def test_smolvla_newline_processor_single_task():
@@ -170,7 +190,7 @@ def test_smolvla_processor_cuda():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessor:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -192,7 +212,9 @@ def test_smolvla_processor_cuda():
def transform_features(self, features):
return features
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
@@ -225,7 +247,7 @@ def test_smolvla_processor_accelerate_scenario():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessor:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -247,7 +269,9 @@ def test_smolvla_processor_accelerate_scenario():
def transform_features(self, features):
return features
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
@@ -281,7 +305,7 @@ def test_smolvla_processor_multi_gpu():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessor:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -303,7 +327,9 @@ def test_smolvla_processor_multi_gpu():
def transform_features(self, features):
return features
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
@@ -334,7 +360,9 @@ def test_smolvla_processor_without_stats():
config = create_default_config()
# Mock the tokenizer processor
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
dataset_stats=None,
@@ -372,7 +400,77 @@ def test_smolvla_newline_processor_transform_features():
# Test transform_features
features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
}
result = processor.transform_features(features)
assert result == features # Should return unchanged
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_smolvla_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, _ = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Verify initial normalizer configuration (SmolVLA has NormalizerProcessorStep at index 5)
normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data with both state and visual observations
observation = {
OBS_STATE: torch.randn(8, dtype=torch.float32),
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(7, dtype=torch.float32)
transition = create_transition(
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
)
# Process through full pipeline
processed = preprocessor(transition)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
assert (
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
) # IDENTITY normalization still gets dtype conversion
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
# Check state stats (has normalization)
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
+101 -22
View File
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
DeviceProcessorStep,
NormalizerProcessorStep,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.pipeline import TransitionKey
def create_transition(observation=None, action=None, **kwargs):
@@ -89,20 +89,20 @@ def test_make_tdmpc_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessor)
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
assert isinstance(preprocessor.steps[3], DeviceProcessor)
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], DeviceProcessor)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
def test_tdmpc_processor_normalization():
@@ -251,13 +251,13 @@ def test_tdmpc_processor_without_stats():
factory_preprocessor, factory_postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
preprocessor = DataProcessorPipeline(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
postprocessor = DataProcessorPipeline(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
@@ -297,7 +297,7 @@ def test_tdmpc_processor_save_and_load():
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -330,10 +330,25 @@ def test_tdmpc_processor_mixed_precision():
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
if isinstance(step, DeviceProcessor):
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
# Replace DeviceProcessorStep with one that uses float16
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Create test data
observation = {
@@ -410,3 +425,67 @@ def test_tdmpc_processor_edge_cases():
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
assert OBS_STATE not in processed[TransitionKey.OBSERVATION]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tdmpc_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, _ = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Verify initial normalizer configuration
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data with both state and visual observations
observation = {
OBS_STATE: torch.randn(12, dtype=torch.float32),
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(6, dtype=torch.float32)
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
assert (
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
) # IDENTITY normalization still gets dtype conversion
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
# Check state stats (has normalization)
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
+91 -70
View File
@@ -1,5 +1,5 @@
"""
Tests for the TokenizerProcessor class.
Tests for the TokenizerProcessorStep class.
"""
import tempfile
@@ -8,10 +8,9 @@ from unittest.mock import patch
import pytest
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
from lerobot.processor.tokenizer_processor import TokenizerProcessor
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
from tests.utils import require_package
@@ -96,7 +95,7 @@ def test_basic_tokenization(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -128,7 +127,7 @@ def test_basic_tokenization_with_tokenizer_object():
"""Test basic string tokenization functionality using tokenizer object directly."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -162,7 +161,7 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -190,7 +189,7 @@ def test_custom_keys(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5)
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -216,7 +215,7 @@ def test_none_complementary_data(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
transition = create_transition(complementary_data=None)
@@ -231,7 +230,7 @@ def test_missing_task_key(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
transition = create_transition(complementary_data={"other_field": "some value"})
@@ -246,7 +245,7 @@ def test_none_task_value(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
transition = create_transition(complementary_data={"task": None})
@@ -261,7 +260,7 @@ def test_unsupported_task_type(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
# Test with integer task
transition = create_transition(complementary_data={"task": 123})
@@ -280,7 +279,7 @@ def test_unsupported_task_type(mock_auto_tokenizer):
def test_no_tokenizer_error():
"""Test that ValueError is raised when neither tokenizer nor tokenizer_name is provided."""
with pytest.raises(ValueError, match="Either 'tokenizer' or 'tokenizer_name' must be provided"):
TokenizerProcessor()
TokenizerProcessorStep()
@require_package("transformers")
@@ -291,7 +290,7 @@ def test_invalid_tokenizer_name_error():
mock_auto_tokenizer.from_pretrained.side_effect = Exception("Model not found")
with pytest.raises(Exception, match="Model not found"):
TokenizerProcessor(tokenizer_name="invalid-tokenizer")
TokenizerProcessorStep(tokenizer_name="invalid-tokenizer")
@require_package("transformers")
@@ -301,7 +300,7 @@ def test_get_config_with_tokenizer_name(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(
processor = TokenizerProcessorStep(
tokenizer_name="test-tokenizer",
max_length=256,
task_key="instruction",
@@ -328,7 +327,7 @@ def test_get_config_with_tokenizer_object():
"""Test configuration serialization when using tokenizer object."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(
processor = TokenizerProcessorStep(
tokenizer=mock_tokenizer,
max_length=256,
task_key="instruction",
@@ -358,7 +357,7 @@ def test_state_dict_methods(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
# Should return empty dict
state = processor.state_dict()
@@ -375,7 +374,7 @@ def test_reset_method(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
# Should not raise error
processor.reset()
@@ -388,8 +387,10 @@ def test_integration_with_robot_processor(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
robot_processor = RobotProcessor([tokenizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
robot_processor = DataProcessorPipeline(
[tokenizer_processor], to_transition=lambda x: x, to_output=lambda x: x
)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -423,18 +424,20 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
original_processor = TokenizerProcessor(
original_processor = TokenizerProcessorStep(
tokenizer_name="test-tokenizer", max_length=32, task_key="instruction"
)
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
robot_processor = DataProcessorPipeline(
[original_processor], to_transition=lambda x: x, to_output=lambda x: x
)
with tempfile.TemporaryDirectory() as temp_dir:
# Save processor
robot_processor.save_pretrained(temp_dir)
# Load processor - tokenizer will be recreated from saved config
loaded_processor = RobotProcessor.from_pretrained(
loaded_processor = DataProcessorPipeline.from_pretrained(
temp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -456,16 +459,20 @@ def test_save_and_load_pretrained_with_tokenizer_object():
"""Test saving and loading processor with tokenizer object using overrides."""
mock_tokenizer = MockTokenizer(vocab_size=100)
original_processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=32, task_key="instruction")
original_processor = TokenizerProcessorStep(
tokenizer=mock_tokenizer, max_length=32, task_key="instruction"
)
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
robot_processor = DataProcessorPipeline(
[original_processor], to_transition=lambda x: x, to_output=lambda x: x
)
with tempfile.TemporaryDirectory() as temp_dir:
# Save processor
robot_processor.save_pretrained(temp_dir)
# Load processor with tokenizer override (since tokenizer object wasn't saved)
loaded_processor = RobotProcessor.from_pretrained(
loaded_processor = DataProcessorPipeline.from_pretrained(
temp_dir,
overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}},
to_transition=lambda x: x,
@@ -488,40 +495,44 @@ def test_save_and_load_pretrained_with_tokenizer_object():
@require_package("transformers")
def test_registry_functionality():
"""Test that the processor is properly registered."""
from lerobot.processor.pipeline import ProcessorStepRegistry
from lerobot.processor import ProcessorStepRegistry
# Check that the processor is registered
assert "tokenizer_processor" in ProcessorStepRegistry.list()
# Check that we can retrieve it
retrieved_class = ProcessorStepRegistry.get("tokenizer_processor")
assert retrieved_class is TokenizerProcessor
assert retrieved_class is TokenizerProcessorStep
@require_package("transformers")
def test_features_basic():
"""Test basic feature contract functionality."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=128)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128)
input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
PipelineFeatureType.OBSERVATION: {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
},
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
}
output_features = processor.transform_features(input_features)
# Check that original features are preserved
assert "observation.state" in output_features
assert "action" in output_features
assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION]
assert "action" in output_features[PipelineFeatureType.ACTION]
# Check that tokenized features are added
assert f"{OBS_LANGUAGE}.tokens" in output_features
assert f"{OBS_LANGUAGE}.attention_mask" in output_features
assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION]
assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION]
# Check feature properties
tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"]
attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"]
tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][
f"{OBS_LANGUAGE}.attention_mask"
]
assert tokens_feature.type == FeatureType.LANGUAGE
assert tokens_feature.shape == (128,)
@@ -533,17 +544,19 @@ def test_features_basic():
def test_features_with_custom_max_length():
"""Test feature contract with custom max_length."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=64)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=64)
input_features = {}
input_features = {PipelineFeatureType.OBSERVATION: {}}
output_features = processor.transform_features(input_features)
# Check that features use correct max_length
assert f"{OBS_LANGUAGE}.tokens" in output_features
assert f"{OBS_LANGUAGE}.attention_mask" in output_features
assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION]
assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION]
tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"]
attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"]
tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][
f"{OBS_LANGUAGE}.attention_mask"
]
assert tokens_feature.shape == (64,)
assert attention_mask_feature.shape == (64,)
@@ -553,18 +566,22 @@ def test_features_with_custom_max_length():
def test_features_existing_features():
"""Test feature contract when tokenized features already exist."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=256)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=256)
input_features = {
f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
PipelineFeatureType.OBSERVATION: {
f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
}
}
output_features = processor.transform_features(input_features)
# Should not overwrite existing features
assert output_features[f"{OBS_LANGUAGE}.tokens"].shape == (100,) # Original shape preserved
assert output_features[f"{OBS_LANGUAGE}.attention_mask"].shape == (100,)
assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"].shape == (
100,
) # Original shape preserved
assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"].shape == (100,)
@require_package("transformers")
@@ -590,7 +607,7 @@ def test_tokenization_parameters(mock_auto_tokenizer):
tracking_tokenizer = TrackingMockTokenizer()
mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer
processor = TokenizerProcessor(
processor = TokenizerProcessorStep(
tokenizer_name="test-tokenizer",
max_length=16,
padding="longest",
@@ -622,7 +639,7 @@ def test_preserves_other_complementary_data(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -657,7 +674,7 @@ def test_deterministic_tokenization(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -685,7 +702,7 @@ def test_empty_string_task(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -709,7 +726,7 @@ def test_very_long_task(mock_auto_tokenizer):
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=5, truncation=True)
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=5, truncation=True)
long_task = " ".join(["word"] * 100) # Very long task
transition = create_transition(
@@ -759,7 +776,9 @@ def test_custom_padding_side(mock_auto_tokenizer):
mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer
# Test left padding
processor_left = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="left")
processor_left = TokenizerProcessorStep(
tokenizer_name="test-tokenizer", max_length=10, padding_side="left"
)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -771,7 +790,9 @@ def test_custom_padding_side(mock_auto_tokenizer):
assert tracking_tokenizer.padding_side_calls[-1] == "left"
# Test right padding (default)
processor_right = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="right")
processor_right = TokenizerProcessorStep(
tokenizer_name="test-tokenizer", max_length=10, padding_side="right"
)
processor_right(transition)
@@ -782,7 +803,7 @@ def test_custom_padding_side(mock_auto_tokenizer):
def test_device_detection_cpu():
"""Test that tokenized tensors stay on CPU when other tensors are on CPU."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with CPU tensors
observation = {"observation.state": torch.randn(10)} # CPU tensor
@@ -806,7 +827,7 @@ def test_device_detection_cpu():
def test_device_detection_cuda():
"""Test that tokenized tensors are moved to CUDA when other tensors are on CUDA."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with CUDA tensors
observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor
@@ -831,7 +852,7 @@ def test_device_detection_cuda():
def test_device_detection_multi_gpu():
"""Test that tokenized tensors match device in multi-GPU setup."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Test with tensors on cuda:1
device = torch.device("cuda:1")
@@ -855,7 +876,7 @@ def test_device_detection_multi_gpu():
def test_device_detection_no_tensors():
"""Test that tokenized tensors stay on CPU when no other tensors exist."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with no tensors
transition = create_transition(
@@ -877,7 +898,7 @@ def test_device_detection_no_tensors():
def test_device_detection_mixed_devices():
"""Test device detection when tensors are on different devices (uses first found)."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
if torch.cuda.is_available():
# Create transition with mixed devices
@@ -905,7 +926,7 @@ def test_device_detection_mixed_devices():
def test_device_detection_from_action():
"""Test that device is detected from action tensor when no observation tensors exist."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with action on CUDA but no observation tensors
observation = {"metadata": {"key": "value"}} # No tensors in observation
@@ -928,7 +949,7 @@ def test_device_detection_from_action():
def test_device_detection_preserves_dtype():
"""Test that device detection doesn't affect dtype of tokenized tensors."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with float tensor (to test dtype isn't affected)
observation = {"observation.state": torch.randn(10, dtype=torch.float16)}
@@ -948,16 +969,16 @@ def test_device_detection_preserves_dtype():
@require_package("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_integration_with_device_processor(mock_auto_tokenizer):
"""Test that TokenizerProcessor works correctly with DeviceProcessor in pipeline."""
from lerobot.processor import DeviceProcessor
"""Test that TokenizerProcessorStep works correctly with DeviceProcessorStep in pipeline."""
from lerobot.processor import DeviceProcessorStep
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
# Create pipeline with TokenizerProcessor then DeviceProcessor
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
device_processor = DeviceProcessor(device="cuda:0")
robot_processor = RobotProcessor(
# Create pipeline with TokenizerProcessorStep then DeviceProcessorStep
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
device_processor = DeviceProcessorStep(device="cuda:0")
robot_processor = DataProcessorPipeline(
[tokenizer_processor, device_processor], to_transition=lambda x: x, to_output=lambda x: x
)
@@ -970,7 +991,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
result = robot_processor(transition)
# All tensors should end up on CUDA (moved by DeviceProcessor)
# All tensors should end up on CUDA (moved by DeviceProcessorStep)
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda"
@@ -986,7 +1007,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
def test_simulated_accelerate_scenario():
"""Test scenario simulating Accelerate with data already on GPU."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Simulate Accelerate scenario: batch already on GPU
device = torch.device("cuda:0")
+102 -22
View File
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
DeviceProcessorStep,
NormalizerProcessorStep,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.pipeline import TransitionKey
def create_transition(observation=None, action=None, **kwargs):
@@ -89,20 +89,20 @@ def test_make_vqbet_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessor)
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
assert isinstance(preprocessor.steps[3], DeviceProcessor)
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], DeviceProcessor)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
def test_vqbet_processor_with_images():
@@ -244,13 +244,13 @@ def test_vqbet_processor_without_stats():
factory_preprocessor, factory_postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
preprocessor = DataProcessorPipeline(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
postprocessor = DataProcessorPipeline(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
@@ -290,7 +290,7 @@ def test_vqbet_processor_save_and_load():
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
@@ -323,10 +323,25 @@ def test_vqbet_processor_mixed_precision():
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
if isinstance(step, DeviceProcessor):
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
# Replace DeviceProcessorStep with one that uses float16
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Create test data
observation = {
@@ -405,3 +420,68 @@ def test_vqbet_processor_sequential_processing():
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
assert result[TransitionKey.ACTION].shape == (1, 7)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_vqbet_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, _ = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Verify initial normalizer configuration
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data with both state and visual observations
observation = {
OBS_STATE: torch.randn(8, dtype=torch.float32),
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(7, dtype=torch.float32)
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
assert (
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
) # IDENTITY normalization still gets dtype conversion
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
# Check state stats (has normalization)
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
+326
View File
@@ -0,0 +1,326 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from lerobot.robots.reachy2 import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Robot,
Reachy2RobotConfig,
)
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_JOINTS = {
**REACHY2_NECK_JOINTS,
**REACHY2_ANTENNAS_JOINTS,
**REACHY2_R_ARM_JOINTS,
**REACHY2_L_ARM_JOINTS,
}
PARAMS = [
{}, # default config
{"with_mobile_base": False},
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
{"use_external_commands": True, "disable_torque_on_disconnect": True},
{"use_external_commands": True, "with_mobile_base": False, "with_neck": False},
{"disable_torque_on_disconnect": False},
{"max_relative_target": 5},
{"with_right_teleop_camera": False},
{"with_left_teleop_camera": False, "with_right_teleop_camera": False},
{"with_left_teleop_camera": False, "with_torso_camera": True},
]
def _make_reachy2_sdk_mock():
class JointSpy:
__slots__ = (
"present_position",
"_goal_position",
"_on_set",
)
def __init__(self, present_position=0.0, on_set=None):
self.present_position = present_position
self._goal_position = present_position
self._on_set = on_set
@property
def goal_position(self):
return self._goal_position
@goal_position.setter
def goal_position(self, v):
self._goal_position = v
if self._on_set:
self._on_set()
r = MagicMock(name="ReachySDKMock")
r.is_connected.return_value = True
def _connect():
r.is_connected.return_value = True
def _disconnect():
r.is_connected.return_value = False
# Global counter of goal_position sets
r._goal_position_set_total = 0
def _on_any_goal_set():
r._goal_position_set_total += 1
# Mock joints with some dummy positions
joints = {
k: JointSpy(
present_position=float(i),
on_set=_on_any_goal_set,
)
for i, k in enumerate(REACHY2_JOINTS.values())
}
r.joints = joints
# Mock mobile base with some dummy odometry
r.mobile_base = MagicMock()
r.mobile_base.odometry = {
"x": 0.1,
"y": -0.2,
"theta": 21.3,
"vx": 0.001,
"vy": 0.002,
"vtheta": 0.0,
}
r.connect = MagicMock(side_effect=_connect)
r.disconnect = MagicMock(side_effect=_disconnect)
# Mock methods
r.turn_on = MagicMock()
r.reset_default_limits = MagicMock()
r.send_goal_positions = MagicMock()
r.turn_off_smoothly = MagicMock()
r.mobile_base.set_goal_speed = MagicMock()
r.mobile_base.send_speed_command = MagicMock()
return r
def _make_reachy2_camera_mock(*args, **kwargs):
cfg = args[0] if args else kwargs.get("config")
name = getattr(cfg, "name", kwargs.get("name", "cam"))
image_type = getattr(cfg, "image_type", kwargs.get("image_type", "cam"))
width = getattr(cfg, "width", kwargs.get("width", 640))
height = getattr(cfg, "height", kwargs.get("height", 480))
cam = MagicMock(name=f"Reachy2CameraMock:{name}")
cam.name = name
cam.image_type = image_type
cam.width = width
cam.height = height
cam.connect = MagicMock()
cam.disconnect = MagicMock()
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
return cam
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
def reachy2(request):
with (
patch(
"lerobot.robots.reachy2.robot_reachy2.ReachySDK",
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
),
patch(
"lerobot.cameras.reachy2_camera.reachy2_camera.Reachy2Camera",
side_effect=_make_reachy2_camera_mock,
),
):
overrides = request.param
cfg = Reachy2RobotConfig(ip_address="192.168.0.200", **overrides)
robot = Reachy2Robot(cfg)
yield robot
if robot.is_connected:
robot.disconnect()
def test_connect_disconnect(reachy2):
assert not reachy2.is_connected
reachy2.connect()
assert reachy2.is_connected
reachy2.reachy.turn_on.assert_called_once()
reachy2.reachy.reset_default_limits.assert_called_once()
reachy2.disconnect()
assert not reachy2.is_connected
if reachy2.config.disable_torque_on_disconnect:
reachy2.reachy.turn_off_smoothly.assert_called_once()
else:
reachy2.reachy.turn_off_smoothly.assert_not_called()
reachy2.reachy.disconnect.assert_called_once()
def test_get_joints_dict(reachy2):
reachy2.connect()
if reachy2.config.with_neck:
assert "neck_yaw.pos" in reachy2.joints_dict
assert "neck_pitch.pos" in reachy2.joints_dict
assert "neck_roll.pos" in reachy2.joints_dict
else:
assert "neck_yaw.pos" not in reachy2.joints_dict
assert "neck_pitch.pos" not in reachy2.joints_dict
assert "neck_roll.pos" not in reachy2.joints_dict
if reachy2.config.with_antennas:
assert "l_antenna.pos" in reachy2.joints_dict
assert "r_antenna.pos" in reachy2.joints_dict
else:
assert "l_antenna.pos" not in reachy2.joints_dict
assert "r_antenna.pos" not in reachy2.joints_dict
if reachy2.config.with_r_arm:
assert "r_shoulder_pitch.pos" in reachy2.joints_dict
assert "r_shoulder_roll.pos" in reachy2.joints_dict
assert "r_elbow_yaw.pos" in reachy2.joints_dict
assert "r_elbow_pitch.pos" in reachy2.joints_dict
assert "r_wrist_roll.pos" in reachy2.joints_dict
assert "r_wrist_pitch.pos" in reachy2.joints_dict
assert "r_wrist_yaw.pos" in reachy2.joints_dict
assert "r_gripper.pos" in reachy2.joints_dict
else:
assert "r_shoulder_pitch.pos" not in reachy2.joints_dict
assert "r_shoulder_roll.pos" not in reachy2.joints_dict
assert "r_elbow_yaw.pos" not in reachy2.joints_dict
assert "r_elbow_pitch.pos" not in reachy2.joints_dict
assert "r_wrist_roll.pos" not in reachy2.joints_dict
assert "r_wrist_pitch.pos" not in reachy2.joints_dict
assert "r_wrist_yaw.pos" not in reachy2.joints_dict
assert "r_gripper.pos" not in reachy2.joints_dict
if reachy2.config.with_l_arm:
assert "l_shoulder_pitch.pos" in reachy2.joints_dict
assert "l_shoulder_roll.pos" in reachy2.joints_dict
assert "l_elbow_yaw.pos" in reachy2.joints_dict
assert "l_elbow_pitch.pos" in reachy2.joints_dict
assert "l_wrist_roll.pos" in reachy2.joints_dict
assert "l_wrist_pitch.pos" in reachy2.joints_dict
assert "l_wrist_yaw.pos" in reachy2.joints_dict
assert "l_gripper.pos" in reachy2.joints_dict
else:
assert "l_shoulder_pitch.pos" not in reachy2.joints_dict
assert "l_shoulder_roll.pos" not in reachy2.joints_dict
assert "l_elbow_yaw.pos" not in reachy2.joints_dict
assert "l_elbow_pitch.pos" not in reachy2.joints_dict
assert "l_wrist_roll.pos" not in reachy2.joints_dict
assert "l_wrist_pitch.pos" not in reachy2.joints_dict
assert "l_wrist_yaw.pos" not in reachy2.joints_dict
assert "l_gripper.pos" not in reachy2.joints_dict
def test_get_observation(reachy2):
reachy2.connect()
obs = reachy2.get_observation()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
expected_keys.update(reachy2.cameras.keys())
assert set(obs.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
if reachy2.config.with_mobile_base:
for vel in REACHY2_VEL.keys():
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
if reachy2.config.with_left_teleop_camera:
assert obs["teleop_left"].shape == (
reachy2.config.cameras["teleop_left"].height,
reachy2.config.cameras["teleop_left"].width,
3,
)
if reachy2.config.with_right_teleop_camera:
assert obs["teleop_right"].shape == (
reachy2.config.cameras["teleop_right"].height,
reachy2.config.cameras["teleop_right"].width,
3,
)
if reachy2.config.with_torso_camera:
assert obs["torso_rgb"].shape == (
reachy2.config.cameras["torso_rgb"].height,
reachy2.config.cameras["torso_rgb"].width,
3,
)
def test_send_action(reachy2):
reachy2.connect()
action = {k: i * 10.0 for i, k in enumerate(reachy2.joints_dict.keys(), start=1)}
if reachy2.config.with_mobile_base:
action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)})
previous_present_position = {
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys()
}
returned = reachy2.send_action(action)
if reachy2.config.max_relative_target is None:
assert returned == action
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
for motor in reachy2.joints_dict.keys():
expected_pos = action[motor]
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.max_relative_target is None:
assert real_pos == expected_pos
else:
assert real_pos == previous_present_position[motor] + np.sign(expected_pos) * min(
abs(expected_pos - real_pos), reachy2.config.max_relative_target
)
if reachy2.config.with_mobile_base:
goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)]
reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)
if reachy2.config.use_external_commands:
reachy2.reachy.send_goal_positions.assert_not_called()
if reachy2.config.with_mobile_base:
reachy2.reachy.mobile_base.send_speed_command.assert_not_called()
else:
reachy2.reachy.send_goal_positions.assert_called_once()
if reachy2.config.with_mobile_base:
reachy2.reachy.mobile_base.send_speed_command.assert_called_once()
def test_no_part_declared():
with pytest.raises(ValueError):
_ = Reachy2RobotConfig(
ip_address="192.168.0.200",
with_mobile_base=False,
with_l_arm=False,
with_r_arm=False,
with_neck=False,
with_antennas=False,
)
@@ -0,0 +1,150 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
import pytest
from lerobot.teleoperators.reachy2_teleoperator import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
REACHY2_NECK_JOINTS,
REACHY2_R_ARM_JOINTS,
REACHY2_VEL,
Reachy2Teleoperator,
Reachy2TeleoperatorConfig,
)
# {lerobot_keys: reachy2_sdk_keys}
REACHY2_JOINTS = {
**REACHY2_NECK_JOINTS,
**REACHY2_ANTENNAS_JOINTS,
**REACHY2_R_ARM_JOINTS,
**REACHY2_L_ARM_JOINTS,
}
PARAMS = [
{}, # default config
{"with_mobile_base": False},
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
{"with_mobile_base": False, "with_neck": False},
{"use_present_position": True},
]
def _make_reachy2_sdk_mock():
r = MagicMock(name="ReachySDKMock")
r.is_connected.return_value = True
def _connect():
r.is_connected.return_value = True
def _disconnect():
r.is_connected.return_value = False
# Mock joints with some dummy positions
joints = {
k: MagicMock(
present_position=float(i),
goal_position=float(i) + 0.5,
)
for i, k in enumerate(REACHY2_JOINTS.values())
}
r.joints = joints
# Mock mobile base with some dummy odometry
r.mobile_base = MagicMock()
r.mobile_base.last_cmd_vel = {
"vx": -0.2,
"vy": 0.2,
"vtheta": 11.0,
}
r.mobile_base.odometry = {
"x": 1.0,
"y": 2.0,
"theta": 20.0,
"vx": 0.1,
"vy": -0.1,
"vtheta": 8.0,
}
r.connect = MagicMock(side_effect=_connect)
r.disconnect = MagicMock(side_effect=_disconnect)
return r
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
def reachy2(request):
with (
patch(
"lerobot.teleoperators.reachy2_teleoperator.reachy2_teleoperator.ReachySDK",
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
),
):
overrides = request.param
cfg = Reachy2TeleoperatorConfig(ip_address="192.168.0.200", **overrides)
robot = Reachy2Teleoperator(cfg)
yield robot
if robot.is_connected:
robot.disconnect()
def test_connect_disconnect(reachy2):
assert not reachy2.is_connected
reachy2.connect()
assert reachy2.is_connected
reachy2.disconnect()
assert not reachy2.is_connected
reachy2.reachy.disconnect.assert_called_once()
def test_get_action(reachy2):
reachy2.connect()
action = reachy2.get_action()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
assert set(action.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
if reachy2.config.use_present_position:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
else:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.with_mobile_base:
if reachy2.config.use_present_position:
for vel in REACHY2_VEL.keys():
assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
else:
for vel in REACHY2_VEL.keys():
assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]]
def test_no_part_declared():
with pytest.raises(ValueError):
_ = Reachy2TeleoperatorConfig(
ip_address="192.168.0.200",
with_mobile_base=False,
with_l_arm=False,
with_r_arm=False,
with_neck=False,
with_antennas=False,
)
+8 -4
View File
@@ -5,7 +5,7 @@ from types import SimpleNamespace
import numpy as np
import pytest
from lerobot.processor.pipeline import TransitionKey
from lerobot.processor import TransitionKey
@pytest.fixture
@@ -86,7 +86,10 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
TransitionKey.ACTION: act,
}
vu.log_rerun_data(transition)
# Extract observation and action data from transition like in the real call sites
obs_data = transition.get(TransitionKey.OBSERVATION, {})
action_data = transition.get(TransitionKey.ACTION, {})
vu.log_rerun_data(observation=obs_data, action=action_data)
# We expect:
# - observation.state.temperature -> Scalar
@@ -141,7 +144,9 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
"vec": np.array([9, 8, 7], dtype=np.float32),
}
vu.log_rerun_data([obs_plain, act_plain])
# Extract observation and action data from list like the old function logic did
# First dict was treated as observation, second as action
vu.log_rerun_data(observation=obs_plain, action=act_plain)
# Expected keys with auto-prefixes
expected = {
@@ -181,7 +186,6 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
vu, calls = mock_rerun
vu.log_rerun_data(
None,
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
action={"action.a": 1.0},
)