Compare commits

...

136 Commits

Author SHA1 Message Date
Adil Zouitine f247aa0701 refactor(tests): update processor test assertions to reflect new preprocessor and postprocessor names (#1869)
- Changed assertions in multiple processor test files to verify the updated names from "robot_preprocessor" and "robot_postprocessor" to "policy_preprocessor" and "policy_postprocessor" for consistency with recent refactoring.
2025-09-04 17:34:06 +02:00
Adil Zouitine 1ac6a6d3fe refactor(constants): rename preprocessor and postprocessor constants for clarity (#1868)
- Updated constant names from PREPROCESSOR_DEFAULT_NAME and POSTPROCESSOR_DEFAULT_NAME to POLICY_PREPROCESSOR_DEFAULT_NAME and POLICY_POSTPROCESSOR_DEFAULT_NAME for better context.
- Adjusted references across multiple files to use the new constant names, ensuring consistency in the codebase.
2025-09-04 17:01:53 +02:00
Steven Palma e698c709d8 fix(deps): use in-house rotation utils over scipy throughout the codebase 2025-09-04 16:44:18 +02:00
Adil Zouitine a988da4789 feat(teleoperation): introduce HasTeleopEvents protocol and enhance teleop event handling (#1866)
- Added the HasTeleopEvents protocol to define a standard for teleoperators that provide control events.
- Implemented a runtime check to ensure teleoperators implement the get_teleop_events() method.
- Updated AddTeleopEventsAsInfoStep to utilize the new protocol, enhancing compatibility with custom teleoperators.
- Improved documentation for clarity on teleoperation event extraction and compatibility with built-in teleoperators.
2025-09-04 16:28:49 +02:00
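The commit above names the `HasTeleopEvents` protocol and its `get_teleop_events()` runtime check. A minimal sketch of that pattern (the `KeyboardTeleop`/`BareTeleop` classes and `extract_events` helper here are illustrative, not the repository's actual code):

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasTeleopEvents(Protocol):
    """Structural type for teleoperators that expose control events."""

    def get_teleop_events(self) -> dict[str, Any]:
        ...


class KeyboardTeleop:
    """Hypothetical teleoperator; satisfies the protocol structurally."""

    def get_teleop_events(self) -> dict[str, Any]:
        return {"is_intervention": False}


class BareTeleop:
    """Lacks get_teleop_events(), so the runtime check rejects it."""


def extract_events(teleop: object) -> dict[str, Any]:
    # Runtime check in the spirit of AddTeleopEventsAsInfoStep: custom
    # teleoperators qualify by shape, without inheriting a base class.
    if not isinstance(teleop, HasTeleopEvents):
        raise TypeError(f"{type(teleop).__name__} must implement get_teleop_events()")
    return teleop.get_teleop_events()
```

Because the protocol is `@runtime_checkable`, `isinstance` only verifies that the method exists, which is exactly what makes third-party teleoperators compatible without importing the base library.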
Adil Zouitine 99963b6968 refactor(dependencies): remove scipy dependency and introduce custom rotation utilities (#1863)
- Removed the scipy dependency from the project to streamline requirements.
- Added a new `rotation.py` module containing a custom `Rotation` class that replicates essential functionalities of `scipy.spatial.transform.Rotation`, allowing for rotation vector, matrix, and quaternion conversions without external dependencies.
- Updated the `robot_kinematic_processor.py` to utilize the new custom rotation utilities.
2025-09-04 16:26:28 +02:00
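One conversion such an in-house `Rotation` class must provide is rotation vector to matrix. A minimal sketch using Rodrigues' formula (this stands in for `scipy.spatial.transform.Rotation.from_rotvec(...).as_matrix()`; it is not the repository's `rotation.py`):

```python
import numpy as np


def rotvec_to_matrix(rotvec: np.ndarray) -> np.ndarray:
    """Rodrigues' formula: rotation vector (axis * angle) -> 3x3 matrix."""
    theta = np.linalg.norm(rotvec)
    if theta < 1e-12:
        return np.eye(3)  # near-zero rotation: identity
    axis = rotvec / theta
    # Skew-symmetric cross-product matrix of the unit axis.
    K = np.array([
        [0.0, -axis[2], axis[1]],
        [axis[2], 0.0, -axis[0]],
        [-axis[1], axis[0], 0.0],
    ])
    # R = I + sin(theta) K + (1 - cos(theta)) K^2
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)
```

For example, a rotation of pi/2 about z maps the x axis onto the y axis.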
Adil Zouitine 332ca4ccc5 refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862)
- Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity.
- Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests.
- Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability.
2025-09-04 16:22:03 +02:00
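The inheritance requirement described above can be sketched as an `isinstance` check at construction time (the `AddOneStep` class and simplified `DataProcessorPipeline` below are illustrative assumptions, not the actual implementation):

```python
from abc import ABC, abstractmethod
from typing import Any


class ProcessorStep(ABC):
    """Base class every pipeline step must now inherit from."""

    @abstractmethod
    def __call__(self, transition: dict[str, Any]) -> dict[str, Any]:
        ...


class DataProcessorPipeline:
    def __init__(self, steps: list[ProcessorStep]):
        for step in steps:
            # Reject bare callables: the explicit base class carries the
            # type contract (and, in the real pipeline, serialization hooks).
            if not isinstance(step, ProcessorStep):
                raise TypeError(f"Step {step!r} must inherit from ProcessorStep")
        self.steps = steps

    def __call__(self, transition: dict[str, Any]) -> dict[str, Any]:
        for step in self.steps:
            transition = step(transition)
        return transition


class AddOneStep(ProcessorStep):
    """Hypothetical step used to exercise the check."""

    def __call__(self, transition: dict[str, Any]) -> dict[str, Any]:
        return {**transition, "x": transition["x"] + 1}
```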
Adil Zouitine fc43246942 feat(record): add transition features to dataset and handle scalar vs array formatting in converters (#1861)
- Introduced new transition features (`next.reward`, `next.done`, `next.truncated`) in the dataset during recording.
- Updated the `transition_to_dataset_frame` function to handle scalar values correctly, ensuring compatibility with expected array formats for reward, done, and truncated features.
2025-09-04 16:17:31 +02:00
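The scalar-vs-array handling mentioned above amounts to promoting bare floats/bools from an env step into the 1-D arrays the dataset expects. A sketch (the `to_feature_array` helper name is an assumption):

```python
import numpy as np


def to_feature_array(value) -> np.ndarray:
    """Promote scalar reward/done/truncated values to 1-D arrays.

    Env steps often yield bare Python floats or bools, while dataset
    features for next.reward / next.done / next.truncated expect shape (1,).
    Already-array inputs pass through unchanged.
    """
    return np.atleast_1d(np.asarray(value))
```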
Adil Zouitine 793ad86fc9 refactor(processor): enforce config_filename requirement for HF Hub loading (#1860)
- Updated the DataProcessorPipeline to require a specific config_filename when loading from Hugging Face Hub, enhancing clarity and preventing errors.
- Simplified local path checks and improved error handling for invalid paths.
- Adjusted tests to reflect the new requirement and ensure proper error handling for various loading scenarios.
2025-09-04 10:31:18 +02:00
Adil Zouitine a6dbb65917 chore(processor): add type alias RobotProcessorPipeline and PolicyProcessorPipeline (#1859)
* feat(processor): introduce PolicyProcessorPipeline and RobotProcessorPipeline as type aliases for DataProcessorPipeline

- Added PolicyProcessorPipeline and RobotProcessorPipeline type aliases to enhance clarity and maintainability in the processor module.
- Updated the __all__ list to include the new pipelines for better module export consistency.

* refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline across multiple modules

- Updated all instances of DataProcessorPipeline to PolicyProcessorPipeline in various processor files for consistency and clarity.
- Adjusted function signatures to reflect the new pipeline type, enhancing maintainability and readability.

* refactor(processor): update hotswap_stats function to use PolicyProcessorPipeline

- Changed the parameter name from robot_processor to policy_processor for clarity.
- Ensured consistency with recent updates to the processor module by reflecting the new pipeline type in the function signature.

* refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline in migrate_policy_normalization.py

- Updated the preprocessor and postprocessor to use PolicyProcessorPipeline for consistency with recent changes in the processor module.
- Enhanced clarity and maintainability by aligning with the new pipeline structure.

* refactor(processor): update hotswap_stats to use PolicyProcessorPipeline

- Changed the parameter type in hotswap_stats from DataProcessorPipeline to PolicyProcessorPipeline for consistency with recent updates.
- Enhanced clarity by updating the function documentation to reflect the new pipeline type.

* refactor(processor): replace DataProcessorPipeline with RobotProcessorPipeline across multiple files

- Updated instances of DataProcessorPipeline to RobotProcessorPipeline in evaluate.py, record.py, replay.py, teleoperate.py, and other relevant files for consistency and clarity.
- Adjusted function signatures and variable types to reflect the new pipeline structure, enhancing maintainability and readability.
2025-09-03 19:01:28 +02:00
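In essence, the aliases introduced by this PR add intent without adding behavior: call sites say whether a pipeline sits on the policy side or the robot side. A simplified sketch (the pipeline class here is heavily reduced):

```python
from typing import Any, Callable


class DataProcessorPipeline:
    """Minimal stand-in for the real pipeline class."""

    def __init__(self, steps: list[Callable[[Any], Any]]):
        self.steps = steps

    def __call__(self, data: Any) -> Any:
        for step in self.steps:
            data = step(data)
        return data


# Type aliases: identical class, self-documenting names at call sites
# (e.g. record.py uses RobotProcessorPipeline, policies use
# PolicyProcessorPipeline).
PolicyProcessorPipeline = DataProcessorPipeline
RobotProcessorPipeline = DataProcessorPipeline
```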
Steven Palma 6c7169c4af chore(processor): rename teleop_phone variable names (#1858) 2025-09-03 18:42:13 +02:00
Adil Zouitine f125d5e3bf refactor(processor): rename internal device variable for clarity (#1857)
- Changed the internal device variable from `_device` to `tensor_device` for improved readability and consistency.
- Updated references throughout the class to reflect the new variable name.
2025-09-03 18:39:06 +02:00
Steven Palma 75dcfd4886 chore(processor): rename merge_features -> combine_feature_dicts (#1856) 2025-09-03 18:20:35 +02:00
Adil Zouitine ff3cbaa872 refactor(processor): rename internal tokenizer variable for clarity (#1855)
- Changed the internal tokenizer variable name from `_tokenizer` to `input_tokenizer` for improved readability and consistency.
- Updated references throughout the class to reflect the new variable name.
2025-09-03 18:20:12 +02:00
Adil Zouitine ce793cde64 chore(processor): add Step suffix to all processors (#1854)
* refactor(processor): rename MapDeltaActionToRobotAction and MapTensorToDeltaActionDict for consistency

* refactor(processor): rename DeviceProcessor to DeviceProcessorStep for consistency across modules

* refactor(processor): rename Torch2NumpyActionProcessor to Torch2NumpyActionProcessorStep for consistency

* refactor(processor): rename Numpy2TorchActionProcessor to Numpy2TorchActionProcessorStep for consistency

* refactor(processor): rename AddTeleopActionAsComplimentaryData to AddTeleopActionAsComplimentaryDataStep for consistency

* refactor(processor): rename ImageCropResizeProcessor and AddTeleopEventsAsInfo for consistency

* refactor(processor): rename TimeLimitProcessor to TimeLimitProcessorStep for consistency

* refactor(processor): rename GripperPenaltyProcessor to GripperPenaltyProcessorStep for consistency

* refactor(processor): rename InterventionActionProcessor to InterventionActionProcessorStep for consistency

* refactor(processor): rename RewardClassifierProcessor to RewardClassifierProcessorStep for consistency

* refactor(processor): rename JointVelocityProcessor to JointVelocityProcessorStep for consistency

* refactor(processor): rename MotorCurrentProcessor to MotorCurrentProcessorStep for consistency

* refactor(processor): rename NormalizerProcessor and UnnormalizerProcessor to NormalizerProcessorStep and UnnormalizerProcessorStep for consistency

* refactor(processor): rename VanillaObservationProcessor to VanillaObservationProcessorStep for consistency

* refactor(processor): rename RenameProcessor to RenameProcessorStep for consistency

* refactor(processor): rename TokenizerProcessor to TokenizerProcessorStep for consistency

* refactor(processor): rename ToBatchProcessor to AddBatchDimensionProcessorStep for consistency

* refactor(processor): update config file name in test for RenameProcessorStep consistency
2025-09-03 18:12:11 +02:00
Steven Palma 029c4a9a76 chore(processor): rename converters function names (#1853)
* chore(processor): rename to_transition_teleop_action -> action_to_transition

* chore(processor): rename to_transition_robot_observation -> observation_to_transition

* chore(processor): rename to_output_robot_action -> transition_to_robot_action
2025-09-03 18:08:54 +02:00
Steven Palma d893bf1e30 chore(processor): rename specialized processor -> XYZProcessorStep (#1852) 2025-09-03 17:30:47 +02:00
Steven Palma 8c796b39f5 chore(processor): rename RobotProcessor -> DataProcessorPipeline (#1850) 2025-09-03 17:13:16 +02:00
Adil Zouitine 4ebe482a7e refactor(processors): enhance transform_features method across multiple processors (#1849)
* refactor(processors): enhance transform_features method across multiple processors

- Updated the transform_features method in various processors to utilize a copy of the features dictionary, ensuring immutability of the original features.
- Added handling for new feature keys and removed obsolete ones in the MapTensorToDeltaActionDict, JointVelocityProcessor, and others.
- Improved readability and maintainability by following consistent patterns in feature transformation.

* refactor(processors): standardize action and observation keys in delta_action_processor and joint_observations_processor

- Updated action and observation keys to use constants for improved readability and maintainability.
- Refactored the transform_features method in multiple processors to ensure consistent handling of feature keys.
- Enhanced error handling by raising exceptions for missing required components in action and observation processing.
- Removed obsolete code and improved overall structure for better clarity.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(processors): remove unused import in joint_observations_processor

* refactor(processors): simplify transform_features method in delta_action_processor

* refactor(processors): streamline transform_features method in ImageCropResizeProcessor

* refactor(processors): improve error handling and streamline transform_features method in phone_processor

- Raised a ValueError for missing position and rotation in action to enhance error handling.

* refactor(processors): enhance error handling in JointVelocityProcessor

- Added a ValueError raise for missing current joint positions in the observation method to improve error handling and ensure the integrity of the transform_features method.

* refactor(processors): simplify transform_features method in robot kinematic processors

* refactor(processors): standardize action keys in phone_processor

* fix(processor): RKP feature obs -> act

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-09-03 16:54:41 +02:00
Steven Palma 2fcc358e98 refactor(processors): add extended api for specialized pipelines (#1848) 2025-09-03 12:28:40 +02:00
Steven Palma b052843f08 refactor(processors): unify import statements by consolidating pipeline imports into the main processor module (#1845) 2025-09-02 18:26:59 +02:00
Steven Palma ebb464c255 refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844) 2025-09-02 17:57:49 +02:00
Steven Palma 2914ae2a96 refactor(processors): add transform_features method to various processors (#1843) 2025-09-02 17:15:01 +02:00
Adil Zouitine 645c87e3a9 refactor(converters): gather converters and refactor the logic (#1833)
* refactor(converters): move batch transition functions to converters module

- Moved `_default_batch_to_transition` and `_default_transition_to_batch` functions from `pipeline.py` to `converters.py` for better organization and separation of concerns.
- Updated references in `RobotProcessor` to use the new location of these functions.
- Added tests to ensure correct functionality of the transition functions, including handling of index and task_index fields.
- Removed redundant tests from `pipeline.py` to streamline the test suite.

* refactor(processor): reorganize EnvTransition and TransitionKey definitions

- Moved `EnvTransition` and `TransitionKey` classes from `pipeline.py` to a new `core.py` module for better structure and maintainability.
- Updated import statements across relevant modules to reflect the new location of these definitions, ensuring consistent access throughout the codebase.

* refactor(converters): rename and update dataset frame conversion functions

- Replaced `to_dataset_frame` with `transition_to_dataset_frame` for clarity and consistency in naming.
- Updated references in `record.py`, `pipeline.py`, and tests to use the new function name.
- Introduced `merge_transitions` to streamline the merging of transitions, enhancing readability and maintainability.
- Adjusted related tests to ensure correct functionality with the new naming conventions.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(processor): solve conflict artefacts

* refactor(converters): remove unused identity function and update type hints for merge_transitions

* refactor(processor): remove unused identity import and clean up gym_manipulator.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-09-02 15:33:38 +02:00
Steven Palma 2c802ac134 refactor(converters): implement unified tensor conversion function (#1841)
- Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors.
- Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability.
- Updated tests to cover new functionality and ensure correct tensor conversion behavior.

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
2025-09-02 13:47:04 +02:00
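A minimal sketch of the `singledispatch`-based `to_tensor` described above (the exact registered types and dtype policy in the real converter may differ):

```python
from functools import singledispatch

import torch


@singledispatch
def to_tensor(value) -> torch.Tensor:
    # Fallback covers numpy arrays, sequences, etc.
    return torch.as_tensor(value)


@to_tensor.register
def _(value: float) -> torch.Tensor:
    # Python floats become float32 tensors rather than float64.
    return torch.tensor(value, dtype=torch.float32)


@to_tensor.register
def _(value: int) -> torch.Tensor:
    return torch.tensor(value)


@to_tensor.register
def _(value: dict) -> dict:
    # Recurse so nested observation dicts convert in a single call.
    return {k: to_tensor(v) for k, v in value.items()}
```

`singledispatch` picks the implementation from the runtime type of the first argument, which keeps each conversion rule in its own small function instead of one long `isinstance` chain.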
Steven Palma 15ffc01fb3 Revert "refactor(converters): implement unified tensor conversion function (#…" (#1840)
This reverts commit a837685bf8.
2025-09-02 13:43:35 +02:00
Adil Zouitine a837685bf8 refactor(converters): implement unified tensor conversion function (#1830)
- Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors.
- Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability.
- Updated tests to cover new functionality and ensure correct tensor conversion behavior.
2025-09-02 13:28:26 +02:00
Adil Zouitine d32b76cc66 refactor(processor): improve processor pipeline typing with generic type (#1810)
* refactor(processor): introduce generic type for to_output

- Always return `TOutput`
- Remove `_prepare_transition`, so `__call__` now always returns `TOutput`
- Update tests accordingly
- This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor

* refactor(processor): consolidate ProcessorKwargs usage across policies

- Removed the ProcessorTypes module and integrated ProcessorKwargs directly into the processor pipeline.
- Updated multiple policy files to utilize the new ProcessorKwargs structure for preprocessor and postprocessor arguments.
- Simplified the handling of processor kwargs by initializing them to empty dictionaries when not provided.
2025-09-02 12:57:14 +02:00
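The typing change above, where `__call__` always returns `TOutput` via a `to_output` converter, can be sketched as (a reduced, illustrative `Pipeline`; the real class carries far more machinery):

```python
from typing import Any, Callable, Generic, TypeVar

TIn = TypeVar("TIn")
TOut = TypeVar("TOut")
Transition = dict[str, Any]


class Pipeline(Generic[TIn, TOut]):
    """__call__ always returns TOut: convert in, run steps, convert out."""

    def __init__(
        self,
        steps: list[Callable[[Transition], Transition]],
        to_transition: Callable[[TIn], Transition],
        to_output: Callable[[Transition], TOut],
    ):
        self.steps = steps
        self.to_transition = to_transition
        self.to_output = to_output

    def __call__(self, data: TIn) -> TOut:
        transition = self.to_transition(data)
        for step in self.steps:
            transition = step(transition)
        return self.to_output(transition)
```

With `_prepare_transition` gone, there is a single, statically typed path from `TIn` to `TOut`, which is what lets `make_processor` accept settings for `to_transition` and `to_output`.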
Adil Zouitine 08fb310eaa refactor(constants, processor): standardize action and observation keys across multiple files (#1808)
- Added new constants for truncated and done states in constants.py.
- Updated references to action and observation keys in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py to use the new constants for improved readability and maintainability.
2025-08-31 22:53:13 +02:00
Steven Palma 574a708950 Merge branch 'main' into user/azouitine/2025-7-4-convert-codebase-with-pipeline 2025-08-31 20:46:59 +02:00
Steven Palma ce665160ae feat(processor): multiple improvements to the pipeline porting (#1749)
* [Port codebase pipeline] General fixes for RL and scripts (#1748)

* Refactor dataset configuration in documentation and codebase

- Updated dataset configuration keys from `dataset_root` to `root` and `num_episodes` to `num_episodes_to_record` for consistency.
- Adjusted replay episode handling by renaming `episode` to `replay_episode`.
- Enhanced documentation
- added specific processor to transform from policy actions to delta actions

* Added Robot action to tensor processor
Added new processor script for dealing with gym specific action processing

* removed RobotAction2Tensor processor; improved choosing observations in actor

* nit in delta action

* added missing reset functions to kinematics

* Adapt teleoperate and replay to pipeline similar to record

* refactor(processors): move to inheritance (#1750)

* fix(teleoperator): improvements phone implementation (#1752)

* fix(teleoperator): protect shared state in phone implementation

* refactor(teleop): separate classes in phone

* fix: solve breaking changes (#1753)

* refactor(policies): multiple improvements (#1754)

* refactor(processor): simpler logic in device processor (#1755)

* refactor(processor): euclidean distance in delta action processor (#1757)

* refactor(processor): improvements to joint observations processor migration (#1758)

* refactor(processor): improvements to tokenizer migration (#1759)

* refactor(processor): improvements to tokenizer migration

* fix(tests): tokenizer tests regression from #1750

* fix(processors): fix float comparison and config in hil processors (#1760)

* chore(teleop): remove unnecessary callbacks in KeyboardEndEffectorTeleop (#1761)

* refactor(processor): improvements normalize pipeline migration (#1756)

* refactor(processor): several improvements normalize processor step

* refactor(processor): more improvements normalize processor

* refactor(processor): more changes to normalizer

* refactor(processor): take a different approach to DRY

* refactor(processor): final design

* chore(record): revert comment and continue deleted (#1764)

* refactor(examples): pipeline phone examples (#1769)

* refactor(examples): phone teleop + teleop script

* refactor(examples): phone replay + replay

* chore(examples): rename phone example files & folders

* feat(processor): fix improvements to the pipeline porting (#1796)

* refactor(processor): enhance tensor device handling in normalization process (#1795)

* refactor(tests): remove unsupported device detection test for complementary data (#1797)

* chore(tests): update ToBatchProcessor test (#1798)

* refactor(tests): remove in-place mutation tests for actions and complementary data in batch processor

* test(tests): add tests for action and task processing in batch processor

* add names for android and ios phone (#1799)

* use _tensor_stats in normalize processor (#1800)

* fix(normalize_processor): correct device reference for tensor epsilon handling (#1801)

* add point 5 add missing feature contracts (#1806)

* Fix PR comments 1452 (#1807)

* use key to determine image

* Address rest of PR comments

* use PolicyFeatures in transform_features

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-08-31 20:38:52 +02:00
AdilZouitine 35c5d43255 chore(processor): Add default names for preprocessor and postprocessor in constants
- Introduced `PREPROCESSOR_DEFAULT_NAME` and `POSTPROCESSOR_DEFAULT_NAME` constants for consistent naming across various processor implementations.
- Updated processor creation in multiple policy files to utilize these constants, enhancing code readability and maintainability.
- Modified the training script to load and save the preprocessor and postprocessor using the new constants.
2025-08-11 18:00:25 +02:00
Steven Palma 95c1e32aa5 Merge branch 'main' into user/azouitine/2025-7-4-convert-codebase-with-pipeline 2025-08-11 13:56:03 +02:00
Michel Aractingi e4db65a127 Remove HILEnvConfig references 2025-08-11 11:14:57 +02:00
Michel Aractingi 0053defa2e Refactor gym_manipulator.py using the universal pipeline (#1650)
* Migrate gym_manipulator to use the pipeline
Added get_teleop_events function to capture relevant events from teleop devices unrelated to actions

* Added the capability to record a dataset

* Added the replay functionality with the pipeline

* Refactored `actor.py` to use the pipeline

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* RL works at this commit - fixed actor.py and bugs in gym_manipulator

* change folder structure to reduce the size of gym_manip

* Refactored hilserl config

* Remove dataset and mode from HilSerlEnvConfig to a GymManipulatorConfig to reduce verbose of configs during training

* format docs

* removed get_teleop_events from abc

* Refactor environment configuration and processing pipeline for GymHIL support. Removed device attribute from HILSerlRobotEnvConfig, added DummyTeleopDevice for simulation, and updated processor creation to accommodate GymHIL environments.

* Improved typing for HILRobotEnv config and GymManipulator config

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Migrated `gym_manipulator` to use a more modular structure similar to phone teleop

* Refactor gripper handling and transition processing in HIL and robot kinematic processors

- Updated gripper position handling to use a consistent key format across processors
- Improved the EEReferenceAndDelta class to handle reference joint positions.
- Added support for discrete gripper actions in the GripperVelocityToJoint processor.
- Refactored the gym manipulator to improve modularity and clarity in processing steps.

* Added delta_action_processor mapping wrapper

* Added missing file delta_action_processor and improved imports in `gym_manipulator`

* nit

* Added missing file joint_observation_processor

* Enhance processing architecture with new teleoperation processors

- Introduced `AddTeleopActionAsComplimentaryData` and `AddTeleopEventsAsInfo` for integrating teleoperator actions and events into transitions.
- Added `Torch2NumpyActionProcessor` and `Numpy2TorchActionProcessor` for seamless conversion between PyTorch tensors and NumPy arrays.
- Updated `__init__.py` to include new processors in module exports, improving modularity and clarity in the processing pipeline.
- GymHIL is now fully supported with HIL using the pipeline

* Refactor configuration structure for gym_hil integration

- Renamed sections for better readability, such as changing "Gym Wrappers Configuration" to "Processor Configuration."
- Enhanced documentation with clear examples for dataset collection and policy evaluation configurations.

* Enhance reset configuration and teleoperation event handling

- Added `terminate_on_success` parameter to `ResetConfig` and `InterventionActionProcessor` for controlling episode termination behavior upon success detection.
- Updated documentation to clarify the impact of `terminate_on_success` on data collection for reward classifier training.
- Refactored teleoperation event handling to use `TeleopEvents` constants for improved readability and maintainability across various modules.

* fix(keyboard teleop): delta action keys

* Added transform features and feature contract

* Added transform features for image crop

* Enum for TeleopEvents

* Update transform_features delta action proc

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-08-11 11:07:55 +02:00
AdilZouitine fd5d8b3d5f refactor(train): Remove unnecessary tensor device handling in training loop 2025-08-08 19:35:15 +02:00
AdilZouitine 5bf82f8229 feat(tests): Add comprehensive tests for various policy processors
- Introduced new test files for ACT, Classifier, Diffusion, PI0, SAC, SmolVLA, TDMPC, and VQBeT policy processors.
- Each test file includes unit tests to validate functionality, including handling of batch sizes, device management, and data type conversions.
- Enhanced test coverage to ensure robustness and reliability of processor implementations across different scenarios.
2025-08-08 19:34:50 +02:00
AdilZouitine 5ca3920611 feat(DeviceProcessor): Enhance tensor processing with device detection and float dtype conversion
- Improved the _process_tensor method to preserve GPU placement for tensors already on a GPU, facilitating multi-GPU training scenarios.
- Introduced a new _detect_device method in TokenizerProcessor to ensure tokenized tensors match the device of existing tensors in transitions.
- Added comprehensive unit tests to validate the functionality of device detection and float dtype conversion across various scenarios.
2025-08-08 19:33:24 +02:00
AdilZouitine 8bde9d0ab7 refactor(factory): streamline processor loading by removing unused comments
- Removed commented-out code related to loading pretrained processors in the make_processor function.
- This change enhances code clarity and maintains focus on the current implementation.
2025-08-08 13:23:26 +02:00
AdilZouitine abcbc16126 refactor(normalization): remove Normalize and Unnormalize classes
- Deleted the Normalize and Unnormalize classes from the normalization module to streamline the codebase.
- Updated tests to ensure compatibility with the removal of these classes, focusing on the new NormalizerProcessor and UnnormalizerProcessor implementations.
- Enhanced the handling of normalization statistics and improved overall code clarity.
2025-08-08 13:23:10 +02:00
AdilZouitine e4fd30a8d4 feat(policies): convert save_policy_to_safetensors with pipeline 2025-08-08 13:21:50 +02:00
Adil Zouitine 5f759b1637 feat(dependencies): Add scipy as a required dependency
- Included `scipy>=1.15.2` in the project dependencies to enhance functionality and support for scientific computing tasks.
2025-08-07 18:09:49 +02:00
Adil Zouitine 6a75b4761a refactor(TokenizerProcessor): improve dependency handling and observation management
- Updated TokenizerProcessor to conditionally import AutoTokenizer based on the availability of the transformers library, enhancing flexibility.
- Modified tokenizer attribute type to Any to accommodate scenarios where transformers may not be installed.
- Improved observation handling by using a more concise approach to manage the transition dictionary, ensuring compatibility with existing data structures.
- Added error handling for missing transformers library, providing clear guidance for users on installation requirements.
2025-08-07 17:07:20 +02:00
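The conditional-import pattern this commit applies to `AutoTokenizer` is a common way to make a heavy dependency optional. A generic, hedged sketch (the `optional_import` helper is illustrative; the real TokenizerProcessor inlines the try/except around `from transformers import AutoTokenizer`):

```python
import importlib
from typing import Any


def optional_import(module_name: str, hint: str) -> Any:
    """Import a module, raising a guided ImportError when it is absent."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(f"{module_name} is required for this step. {hint}") from e
```

Usage would look like `transformers = optional_import("transformers", "Install it with: pip install transformers")` inside the step's constructor, so the error only fires when the step is actually used.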
Pepijn e5ade5565d Integrate pipeline and add phone teleop (#1681)
* Add normalization processor and related components

- Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization.
- Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks.
- Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports.
- Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity.
- Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enhance processing architecture with new components

- Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility.
- Updated `__init__.py` to include `RenameProcessor` in module exports.
- Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling.
- Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.

* chore (docs): add docstring for processor

* fix (test): test factory

* fix(test): policies

* Update tests/processor/test_observation_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>

* chore(test): add suggestion made by copilot regarding numpy test

* fix(test): import issue

* Refactor normalization components and update tests

- Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity.
- Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`.
- Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes.
- Enhanced handling of missing statistics in normalization processes.

* chore (docstring): Improve docstring for NormalizerProcessor

* feat (device processor): Implement device processor

* chore (batch handling): Enhance processing components with batch conversion utilities

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(test): linting issue

* chore (output format): improves output format

* chore (type): add typing for multiprocess envs

* feat (overrides): Implement support for loading processors with parameter overrides

- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.

* chore(normalization): addressing comments from copilot

* chore(learner): nit comment from copilot

* feat(pipeline): Enhance step_through method to support both tuple and dict inputs

* refactor(pipeline): Simplify observation and padding data handling in batch transitions

* Apply suggestions from code review

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions

* fix(ci): temporary fix on dataset deps version

* feat(processors): Introduce processors for various policy types

- Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`.
- Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps.
- Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity.
- Enhanced test coverage to validate the integration of new processors with existing policy configurations.
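
A `make_processor`-style factory typically dispatches on the policy type through a registry of builder functions. The builders and returned step lists below are placeholders, not the actual lerobot implementation:

```python
# Illustrative sketch of a policy-type -> processor factory.

def _make_act(cfg):
    return ["normalizer", "to_batch"]           # placeholder step names

def _make_diffusion(cfg):
    return ["rename", "normalizer", "to_batch"]

_PROCESSOR_BUILDERS = {
    "act": _make_act,
    "diffusion": _make_diffusion,
    # ... one builder per policy type (tdmpc, vqbet, pi0, sac, ...)
}

def make_processor(policy_type, cfg=None):
    """Dispatch to the builder registered for this policy type."""
    try:
        builder = _PROCESSOR_BUILDERS[policy_type]
    except KeyError:
        raise ValueError(f"Unknown policy type: {policy_type!r}") from None
    return builder(cfg)
```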

* refactor(learner): Remove normalization from cached image features retrieval

- Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls.
- This change enhances clarity and aligns with the recent updates to policy processors.

* refactor(policies): Remove unnormalization step from action predictions

- Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction.
- This change improves code clarity and aligns with recent updates to policy processors.

* feat(train): Integrate preprocessor into training pipeline

* refactor(train): Update preprocessor initialization to include dataset statistics

* refactor(policies): Enhance processor creation and add NaN detection hook

* refactor(train): Update memory pinning logic for mps compatibility

* feat: initial commit phone teleop

* ugly delta control

* use quaternion

* Refactor observation preprocessing to use a modular pipeline system

- Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations.
- Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline.
- Added tests for the new processing components and ensured they match the original functionality.
- Removed hardcoded logic in favor of a more flexible, composable architecture.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor observation processing and improve modularity

- Updated `ObservationProcessor` to enhance the modular design for processing observations.
- Cleaned up imports and improved code readability by removing unnecessary lines and comments.
- Ensured backward compatibility while integrating new processing components.
- Added tests to validate the functionality of the updated processing architecture.

* Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability.

* Refactor processing architecture to use RobotProcessor

- Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity.
- Introduced ProcessorStepRegistry for better management of processing steps.
- Updated relevant documentation and tests to reflect the new processing structure.
- Enhanced the save/load functionality to support the new processor design.
- Added a model card template for RobotProcessor to facilitate sharing and documentation.

* Add RobotProcessor tutorial to documentation

- Introduced a new tutorial on using RobotProcessor for preprocessing robot data.
- Added a section in the table of contents for easy navigation to the new tutorial.
- The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add normalization processor and related components

- Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization.
- Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks.
- Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports.
- Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity.
- Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing.
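
A key-renaming step like the `RenameProcessor` described here is essentially a dict re-keying pass; the mapping and observation layout below are illustrative assumptions:

```python
# Minimal sketch of a RenameProcessor-style step: keys in the map are renamed,
# everything else passes through unchanged.

class RenameProcessor:
    def __init__(self, rename_map):
        self.rename_map = rename_map

    def __call__(self, observation):
        # Build a new dict rather than mutating the caller's observation.
        return {self.rename_map.get(k, k): v for k, v in observation.items()}

rename = RenameProcessor({"pixels": "observation.image", "agent_pos": "observation.state"})
out = rename({"pixels": [[0]], "agent_pos": [1.0], "reward": 0.5})
```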

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enhance processing architecture with new components

- Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility.
- Updated `__init__.py` to include `RenameProcessor` in module exports.
- Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling.
- Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.

* chore(docs): add docstring for processor

* fix(test): test factory

* fix(test): policies

* Update tests/processor/test_observation_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>

* chore(test): add suggestion made by copilot regarding numpy test

* fix(test): import issue

* Refactor normalization components and update tests

- Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity.
- Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`.
- Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes.
- Enhanced handling of missing statistics in normalization processes.

* chore(docstring): Improve docstring for NormalizerProcessor

* feat(device processor): Implement device processor

* chore(batch handling): Enhance processing components with batch conversion utilities

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(test): linting issue

* chore(output format): improves output format

* chore(type): add typing for multiprocess envs

* feat(overrides): Implement support for loading processors with parameter overrides

- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.

* chore(normalization): addressing comments from copilot

* chore(learner): nit comment from copilot

* feat(pipeline): Enhance step_through method to support both tuple and dict inputs

* refactor(pipeline): Simplify observation and padding data handling in batch transitions

* Apply suggestions from code review

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
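
The tuple-to-dict migration can be sketched with a string-valued enum standing in for `TransitionKey`; the key names here mirror the description but are assumptions:

```python
# Sketch of the EnvTransition change: named keys replace positional tuple
# indices, so transition[TransitionKey.OBSERVATION] replaces transition[0].

from enum import Enum

class TransitionKey(str, Enum):
    OBSERVATION = "observation"
    ACTION = "action"
    REWARD = "reward"
    DONE = "done"
    COMPLEMENTARY_DATA = "complementary_data"

def make_transition(observation, action, reward=0.0, done=False):
    return {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: action,
        TransitionKey.REWARD: reward,
        TransitionKey.DONE: done,
        TransitionKey.COMPLEMENTARY_DATA: {},
    }

transition = make_transition({"state": [0.0]}, [1.0], reward=1.0)
obs = transition[TransitionKey.OBSERVATION]  # named access, self-documenting
```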

* refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling

- Introduced constants for observation keys to enhance readability.
- Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly.
- Updated the environment state and agent position assignments to use the new constants, improving maintainability.

* feat(pipeline): Add hook unregistration functionality and enhance documentation

- Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management.
- Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks.
- Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.
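
The register/unregister pairing described here can be sketched as follows; the method names mirror the commit text, but the class shape is an assumption:

```python
# Sketch of hook registration and unregistration on a processor pipeline.
# Hooks observe transitions only; their return values are ignored.

class Processor:
    def __init__(self, steps):
        self.steps = steps
        self._before_hooks = []

    def register_before_hook(self, fn):
        self._before_hooks.append(fn)

    def unregister_before_hook(self, fn):
        try:
            self._before_hooks.remove(fn)
        except ValueError:
            raise ValueError(f"Hook {fn!r} was never registered") from None

    def __call__(self, transition):
        for hook in self._before_hooks:
            hook(transition)              # observe only; cannot modify the transition
        for step in self.steps:
            transition = step(transition)
        return transition

seen = []
hook = seen.append
proc = Processor([lambda t: t + 1])
proc.register_before_hook(hook)
result = proc(1)                          # hook sees the raw input, then steps run
proc.unregister_before_hook(hook)         # later calls no longer invoke the hook
```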

* refactor(pipeline): Clarify hook behavior and improve documentation

- Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
- Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions.
- Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
- Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method.

* feat(pipeline): Add __repr__ method to RobotProcessor for improved readability

- Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed.
- Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values.
- Ensured that the representation handles long lists of steps with truncation for better readability.

* chore(pipeline): Move _CFG_NAME alongside other class members

* refactor(pipeline): Utilize get_safe_torch_device for device assignment

- Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling.
- This change enhances code readability and maintains consistency in device management across the RobotProcessor class.

* refactor(pipeline): Enhance state filename generation and profiling method

- Updated state filename generation to use the registry name when available, improving clarity in saved files.
- Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling.
- Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results.
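
The warmup-plus-deepcopy profiling idea reads as follows in a minimal sketch; function and parameter names are illustrative, not the actual `profile_steps` signature:

```python
# Sketch of per-step profiling: warmup runs are excluded from timing, and each
# run receives a deep copy of the transition so every measurement sees
# identical input even if a step mutates its argument.

import copy
import time

def profile_steps(steps, transition, runs=5, warmup_runs=2):
    timings = {}
    for step in steps:
        for _ in range(warmup_runs):          # warm caches; not timed
            step(copy.deepcopy(transition))
        start = time.perf_counter()
        for _ in range(runs):
            step(copy.deepcopy(transition))   # fresh copy keeps runs comparable
        timings[step.__name__] = (time.perf_counter() - start) / runs
    return timings

def double(t):
    return {k: v * 2 for k, v in t.items()}

timings = profile_steps([double], {"x": 1}, runs=3, warmup_runs=1)
```

Note the deepcopy cost is included in the measured time, which is acceptable when the goal is comparing steps under identical conditions rather than absolute numbers.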

* chore(doc): address pip install lerobot command that does not exist yet

* feat(pipeline): Enhance configuration filename handling and state file naming

- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
- Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved.
- Added automatic detection for configuration files when loading from a directory, with error handling for multiple files.
- Updated tests to validate new features, including custom filenames and automatic config detection.

* refactor(pipeline): Improve state file naming conventions for clarity and uniqueness

- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
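
The sanitized, index-prefixed naming scheme can be sketched like this; the exact filename pattern is an assumption based on the description:

```python
# Sketch of unique state-file naming: processor name is sanitized, and the step
# index plus registry name disambiguate multiple steps of the same type.

import re

def sanitize(name):
    """Lowercase and collapse any non-alphanumeric run into an underscore."""
    return re.sub(r"[^a-z0-9]+", "_", name.lower()).strip("_")

def state_filename(processor_name, step_index, step_registry_name):
    return (
        f"{sanitize(processor_name)}_step_{step_index}_"
        f"{sanitize(step_registry_name)}.safetensors"
    )

fname = state_filename("My Processor", 0, "normalizer_processor")
```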

* docs(pipeline): Add clarification for repo name sanitization process

* feat(processors): Introduce processors for various policy types

- Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`.
- Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps.
- Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity.
- Enhanced test coverage to validate the integration of new processors with existing policy configurations.

* refactor(learner): Remove normalization from cached image features retrieval

- Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls.
- This change enhances clarity and aligns with the recent updates to policy processors.

* refactor(policies): Remove unnormalization step from action predictions

- Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction.
- This change improves code clarity and aligns with recent updates to policy processors.

* feat(train): Integrate preprocessor into training pipeline

* refactor(train): Update preprocessor initialization to include dataset statistics

* refactor(policies): Enhance processor creation and add NaN detection hook

* feat(record): Integrate RobotProcessor into recording loop and update policy handling

- Added support for RobotProcessor in the record_loop function to enhance data processing capabilities.
- Updated the logic to reset both policy and processor when provided, ensuring proper state management.
- Modified action prediction to utilize the processor, improving the overall functionality of the recording process.
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.

* feat(migration): Add script for migrating policy models with normalization layers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models

- Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
- Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
- Improved error handling for missing mandatory configuration fields during model instantiation.

* feat(migrate): Add model card generation and saving to migration script

- Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags.
- Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility.
- Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub.

* feat(processor): Introduce ToBatchProcessor for handling observation batching

- Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing.
- Implemented functionality to add batch dimensions to state and image observations as needed.
- Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types.
- Ensured compatibility with existing transition keys and maintained the integrity of non-observation data.
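
The batching rule can be sketched with NumPy arrays standing in for torch tensors; the key suffixes and dimensionality checks are assumptions drawn from the description:

```python
# Sketch of ToBatchProcessor's observation handling: prepend a batch axis to
# unbatched 1-D state vectors and 3-D (C, H, W) images, leave the rest alone.

import numpy as np

def add_batch_dim(observation):
    out = dict(observation)                       # shallow copy; don't mutate input
    for key, value in observation.items():
        if key.endswith("state") and value.ndim == 1:
            out[key] = value[np.newaxis, :]       # (D,) -> (1, D)
        elif key.endswith("image") and value.ndim == 3:
            out[key] = value[np.newaxis, ...]     # (C, H, W) -> (1, C, H, W)
    return out

obs = {"observation.state": np.zeros(4), "observation.image": np.zeros((3, 8, 8))}
batched = add_batch_dim(obs)
```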

* feat(processors): Add ToBatchProcessor to multiple policy processors

- Integrated ToBatchProcessor into various policy processors to handle observation batching.
- Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality.
- Ensured consistency across all processor implementations for improved data handling.

* refactor(factory): Remove unused imports and NaN detection hook from processor creation

* feat(batch_processor): Enhance ToBatchProcessor to handle action batching

- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration

- Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration.
- Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation.
- Refactored the loading of pretrained processors to utilize the new configuration options.

* refactor(factory): Clean up imports in factory.py

- Removed unused import of IdentityProcessor to streamline the code.

* feat(migrate): Extend load_model_from_hub to include train configuration

- Updated load_model_from_hub to return the train configuration alongside the model state_dict and config.
- Modified main function to handle the additional train configuration when loading models from both the hub and local paths.
- Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy.

* refactor(record): Rename processor parameters and update processing logic

- Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity.
- Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow.
- Ensured compatibility with existing functionality while improving code readability.

* feat(batch_processor): Add task field processing to ToBatchProcessor

- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference.
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings.
- Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.

* feat(normalization): Implement IDENTITY mode for normalization and unnormalization

- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified.
- Updated processing logic to check normalization modes and handle missing statistics gracefully.
- Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios.
- Improved error handling for unsupported normalization modes.
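
The IDENTITY bypass and missing-stats handling can be sketched as one dispatch; the mode names follow the commit, the rest is an illustrative assumption:

```python
# Sketch of normalization modes: IDENTITY (or absent stats) passes values
# through, MEAN_STD standardizes, anything else is rejected.

import numpy as np

def normalize(features, stats, modes):
    out = {}
    for key, value in features.items():
        mode = modes.get(key, "IDENTITY")
        if mode == "IDENTITY" or key not in stats:
            out[key] = value                      # bypass untouched
        elif mode == "MEAN_STD":
            s = stats[key]
            out[key] = (value - s["mean"]) / (s["std"] + 1e-8)
        else:
            raise ValueError(f"Unsupported normalization mode: {mode}")
    return out

stats = {"observation.state": {"mean": np.array([1.0]), "std": np.array([2.0])}}
out = normalize(
    {"observation.state": np.array([3.0]), "action": np.array([5.0])},
    stats,
    modes={"observation.state": "MEAN_STD", "action": "IDENTITY"},
)
```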

* fix(rebase): remove residual normalization layer

* refactor(diffusion): remove normalization layer from input processing

* Add debug + calib

* cleanup

* Add pipeline

* fix int

* Add record example

* nit

* Add feature contract to pipelinestep and pipeline

* Add tests

* Add processor tests

* PR feedback

* incorporate PR feedback

* typo in doc

* oops

* cleaned up steps and integrated pipeline with feature_contract

* refactor steps and robot to pipeline

* cleanup pipeline

* cleanup code further

* make it run

* feat(processors): Introduce processors for various policy types

- Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`.
- Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps.
- Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity.
- Enhanced test coverage to validate the integration of new processors with existing policy configurations.

* refactor(learner): Remove normalization from cached image features retrieval

- Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls.
- This change enhances clarity and aligns with the recent updates to policy processors.

* refactor(policies): Remove unnormalization step from action predictions

- Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction.
- This change improves code clarity and aligns with recent updates to policy processors.

* feat(train): Integrate preprocessor into training pipeline

* refactor(train): Update preprocessor initialization to include dataset statistics

* refactor(policies): Enhance processor creation and add NaN detection hook

* feat(record): Integrate RobotProcessor into recording loop and update policy handling

- Added support for RobotProcessor in the record_loop function to enhance data processing capabilities.
- Updated the logic to reset both policy and processor when provided, ensuring proper state management.
- Modified action prediction to utilize the processor, improving the overall functionality of the recording process.
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.

* feat(migration): Add script for migrating policy models with normalization layers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models

- Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
- Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
- Improved error handling for missing mandatory configuration fields during model instantiation.

* feat(migrate): Add model card generation and saving to migration script

- Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags.
- Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility.
- Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub.

* feat(processor): Introduce ToBatchProcessor for handling observation batching

- Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing.
- Implemented functionality to add batch dimensions to state and image observations as needed.
- Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types.
- Ensured compatibility with existing transition keys and maintained the integrity of non-observation data.

* feat(processors): Add ToBatchProcessor to multiple policy processors

- Integrated ToBatchProcessor into various policy processors to handle observation batching.
- Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality.
- Ensured consistency across all processor implementations for improved data handling.

* refactor(factory): Remove unused imports and NaN detection hook from processor creation

* feat(batch_processor): Enhance ToBatchProcessor to handle action batching

- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration

- Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration.
- Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation.
- Refactored the loading of pretrained processors to utilize the new configuration options.

* refactor(factory): Clean up imports in factory.py

- Removed unused import of IdentityProcessor to streamline the code.

* feat(migrate): Extend load_model_from_hub to include train configuration

- Updated load_model_from_hub to return the train configuration alongside the model state_dict and config.
- Modified main function to handle the additional train configuration when loading models from both the hub and local paths.
- Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy.

* refactor(record): Rename processor parameters and update processing logic

- Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity.
- Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow.
- Ensured compatibility with existing functionality while improving code readability.

* feat(batch_processor): Add task field processing to ToBatchProcessor

- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference.
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings.
- Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.

* feat(normalization): Implement IDENTITY mode for normalization and unnormalization

- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified.
- Updated processing logic to check normalization modes and handle missing statistics gracefully.
- Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios.
- Improved error handling for unsupported normalization modes.

* fix(rebase): remove residual normalization layer

* refactor(diffusion): remove normalization layer from input processing

* refactor(normalization): Remove unused state dict transformation methods and streamline imports

- Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process.
- Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging.
- Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation.
- Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.
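
A `hotswap_stats`-style helper amounts to swapping the stats object on every normalizer-like step in place; the class and attribute names below are assumptions:

```python
# Minimal sketch of dynamically replacing normalization statistics across a
# pipeline's steps without rebuilding the pipeline.

class NormalizerStep:
    def __init__(self, stats):
        self.stats = stats

def hotswap_stats(steps, new_stats):
    """Replace the stats dict on every step that carries one; return the count."""
    swapped = 0
    for step in steps:
        if hasattr(step, "stats"):
            step.stats = new_stats
            swapped += 1
    return swapped

pipeline = [NormalizerStep({"mean": 0.0}), object()]  # second step has no stats
count = hotswap_stats(pipeline, {"mean": 1.0})
```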

* refactor(normalization): Clean up imports in normalize_processor.py

* feat(batch_processor): Add feature_contract method to ToBatchProcessor

- Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor.
- This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs.

* fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(tokenizer): Introduce TokenizerProcessor for text tokenization

- Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer.
- Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings.
- Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor.
- Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor.

* feat(language): Enhance language processing in TokenizerProcessor

- Added OBS_LANGUAGE constant to define the observation language key.
- Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature.
- Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization.
- Modified tests to validate the integration of language tokens and attention masks in the observation structure.
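
The newline-appending behaviour of the `Pi0NewLineProcessor` described here is simple to sketch; the function name and accepted input shapes are illustrative:

```python
# Sketch of ensuring every task string ends with a newline before tokenization,
# for both a single task and a batch of tasks.

def append_newline(task):
    if isinstance(task, str):
        return task if task.endswith("\n") else task + "\n"
    # assume a list of task strings; already-terminated entries are untouched
    return [t if t.endswith("\n") else t + "\n" for t in task]

single = append_newline("pick up the cube")
batch = append_newline(["stack blocks\n", "open drawer"])
```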

* feat(tokenizer): Add padding configuration to TokenizerProcessor

- Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction.
- Updated the `make_pi0_processor` function to include the new padding configuration.
- Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios.

* feat(processor): Add state management methods to Pi0NewLineProcessor

* feat(normalization): Track normalization and unnormalization info in complementary data

- Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes.
- Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions.
- Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys.

* feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs

- Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations.
- Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization.

* feat(processors): Integrate RenameProcessor into various processor configurations

- Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency.
- Updated the input steps to ensure compatibility with the new RenameProcessor integration.

* Do some todos and cleanup

* change feature_contract to dataset_features

* use one method for converting pipeline output to the add_frame dict and use base processors where possible

* Add back in and use record_loop

* update todo

* rename to_dataset_frame

* feat(smolvla): Refactor language processing and introduce new line processor (#1658)

- Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant.
- Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility.
- Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling.

* feat(processors): Integrate DeviceProcessor into multiple processor configurations

- Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines.
- Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix reference frame

* refactor(pipeline): Remove to() method for device management

- Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices.
- Removed associated unit tests that validated the functionality of the to() method across various scenarios.
- Streamlined the pipeline code by focusing on other device management strategies.

* feat(processor): Enhance DeviceProcessor with float dtype conversion

- Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types.
- Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype.
- Refactored tensor processing logic to streamline device movement and dtype conversion.
- Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios.

* update data visualization

* update teleop example

* fix record bugs

* Add replay

* Not code

* feature(pipeline): port tokenizer pipeline for VLA (#1645)

* feat(tokenizer): Introduce TokenizerProcessor for text tokenization

- Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer.
- Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings.
- Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor.
- Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor.

* feat(language): Enhance language processing in TokenizerProcessor

- Added OBS_LANGUAGE constant to define the observation language key.
- Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature.
- Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization.
- Modified tests to validate the integration of language tokens and attention masks in the observation structure.

* feat(tokenizer): Add padding configuration to TokenizerProcessor

- Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction.
- Updated the `make_pi0_processor` function to include the new padding configuration.
- Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios.

* feat(processor): Add state management methods to Pi0NewLineProcessor

* feat(normalization): Track normalization and unnormalization info in complementary data

- Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes.
- Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions.
- Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys.

* feat(policies): Add new line processors and update module exports

* feat(processor): Enhance batch and device processors to handle index and task_index fields

- Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors.
- Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged.
- Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation.
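The 0-D handling above is a targeted `unsqueeze` on scalar `index`/`task_index` tensors; a hedged sketch (key names taken from the description, helper name invented):

```python
import torch


def batch_index_fields(complementary_data: dict) -> dict:
    """Give 0-D `index`/`task_index` tensors a batch dimension.

    Non-tensor fields and already-batched tensors pass through unchanged.
    A sketch of the handling described above, not the real processor code.
    """
    out = dict(complementary_data)
    for key in ("index", "task_index"):
        value = out.get(key)
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            out[key] = value.unsqueeze(0)
    return out
```

A scalar `torch.tensor(3)` under `"index"` becomes shape `(1,)`, while string fields such as the task are untouched.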

* Add eval script

* fix: set `q_curr` in InverseKinematicsEEToJoints to the IK solution


* feat(processors): Introduce processors for various policy types

- Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`.
- Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps.
- Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity.
- Enhanced test coverage to validate the integration of new processors with existing policy configurations.

* refactor(learner): Remove normalization from cached image features retrieval

- Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls.
- This change enhances clarity and aligns with the recent updates to policy processors.

* refactor(policies): Remove unnormalization step from action predictions

- Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction.
- This change improves code clarity and aligns with recent updates to policy processors.

* feat(train): Integrate preprocessor into training pipeline

* refactor(train): Update preprocessor initialization to include dataset statistics

* refactor(policies): Enhance processor creation and add NaN detection hook

* feat(record): Integrate RobotProcessor into recording loop and update policy handling

- Added support for RobotProcessor in the record_loop function to enhance data processing capabilities.
- Updated the logic to reset both policy and processor when provided, ensuring proper state management.
- Modified action prediction to utilize the processor, improving the overall functionality of the recording process.
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.

* feat(migration): Add script for migrating policy models with normalization layers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models

- Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
- Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
- Improved error handling for missing mandatory configuration fields during model instantiation.

* feat(migrate): Add model card generation and saving to migration script

- Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags.
- Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility.
- Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub.

* feat(processor): Introduce ToBatchProcessor for handling observation batching

- Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing.
- Implemented functionality to add batch dimensions to state and image observations as needed.
- Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types.
- Ensured compatibility with existing transition keys and maintained the integrity of non-observation data.
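The batching rule described above can be sketched as an `unsqueeze(0)` on unbatched tensors, treating 1-D values as state vectors and 3-D values as `(C, H, W)` images (an illustrative simplification; the real ToBatchProcessor's rules may differ):

```python
import torch


def add_batch_dim(observation: dict) -> dict:
    """Unsqueeze unbatched observation tensors.

    1-D state vectors become (1, D) and 3-D images (C, H, W) become
    (1, C, H, W); everything else passes through. A sketch of the
    ToBatchProcessor idea, not the actual implementation.
    """
    out = {}
    for key, value in observation.items():
        if isinstance(value, torch.Tensor) and value.dim() in (1, 3):
            out[key] = value.unsqueeze(0)
        else:
            out[key] = value
    return out
```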

* feat(processors): Add ToBatchProcessor to multiple policy processors

- Integrated ToBatchProcessor into various policy processors to handle observation batching.
- Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality.
- Ensured consistency across all processor implementations for improved data handling.

* refactor(factory): Remove unused imports and NaN detection hook from processor creation

* feat(batch_processor): Enhance ToBatchProcessor to handle action batching

- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration

- Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration.
- Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation.
- Refactored the loading of pretrained processors to utilize the new configuration options.

* refactor(factory): Clean up imports in factory.py

- Removed unused import of IdentityProcessor to streamline the code.

* feat(migrate): Extend load_model_from_hub to include train configuration

- Updated load_model_from_hub to return the train configuration alongside the model state_dict and config.
- Modified main function to handle the additional train configuration when loading models from both the hub and local paths.
- Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy.

* refactor(record): Rename processor parameters and update processing logic

- Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity.
- Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow.
- Ensured compatibility with existing functionality while improving code readability.

* feat(batch_processor): Add task field processing to ToBatchProcessor

- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference.
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings.
- Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.
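Wrapping a bare task string in a list, as described above, is a small type-dispatch; a hedged sketch (helper name invented here):

```python
def batch_task(task):
    """Wrap a bare task string in a single-element list so it carries a
    batch dimension; pass lists of strings through unchanged.

    A sketch of the task handling described above, not the real code.
    """
    if isinstance(task, str):
        return [task]
    if isinstance(task, list) and all(isinstance(t, str) for t in task):
        return task
    raise TypeError(f"task must be str or list[str], got {type(task)}")
```

So `batch_task("pick")` gives `["pick"]`, and an already-batched `["a", "b"]` is returned as-is.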

* feat(normalization): Implement IDENTITY mode for normalization and unnormalization

- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified.
- Updated processing logic to check normalization modes and handle missing statistics gracefully.
- Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios.
- Improved error handling for unsupported normalization modes.

* fix(rebase): remove residual normalization layer

* refactor(diffusion): remove normalization layer from input processing

* refactor(normalization): Remove unused state dict transformation methods and streamline imports

- Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process.
- Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging.
- Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation.
- Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.

* refactor(normalization): Clean up imports in normalize_processor.py

* feat(batch_processor): Add feature_contract method to ToBatchProcessor

- Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor.
- This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs.

* fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(processors): Standardize processor naming conventions

- Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format.
- Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme.
- Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability.

* refactor(factory): Update processor configuration and type hints

- Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety.
- Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility.
- Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations.
- Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling.

* Fix eval and Android gripper

* add some tests

* refactor(factory, pi0fast): Update processor function names and parameters

- Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency.
- Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions.

* fix(train.py): push postprocessor with preprocessor
- Add preprocessor policy overrides for device and rename_map
- Add rename_map to DatasetRecordConfig (record.py)

* Cleanup pr

* fix more git diff pr issues

* add path as type in save_pretrained

* small nit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename test file

* fix: make dataset_features/feature_contract optional

* fix tests

* Incorporate PR feedback

* clean up record.py

* add ASCII art, fix normal record

* remove merge issues

* fix merge

* remove features

* Add feedback PR

* fix last 4 tests

* remove features check

* rename to transform_features

* add transform_features

* fix lekiwi eval and update eval api example

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-08-07 16:13:34 +02:00
Adil Zouitine 0524551f52 refactor(migrate_policy_normalization): Enhance preprocessor and postprocessor structure
- Introduced RenameProcessor in the preprocessor to handle renaming features.
- Combined input and output features in a single NormalizerProcessor for improved efficiency.
- Updated RobotProcessor initialization to clarify step naming for preprocessor and postprocessor.
- Added DeviceProcessor to both preprocessor and postprocessor for better device management.
2025-08-07 11:04:15 +02:00
Steven Palma 862bc7ef85 Merge branch 'main' into user/azouitine/2025-7-4-convert-codebase-with-pipeline 2025-08-06 21:08:32 +02:00
Adil Zouitine d38792d6e5 test(tokenizer_processor): Add require_package decorator for transformers
- Introduced @require_package("transformers") decorator in multiple test functions to ensure the transformers package is available before running tests.
- This change enhances test reliability by preventing failures due to missing dependencies.
2025-08-06 19:22:23 +02:00
pre-commit-ci[bot] db3cf0158c [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-06 16:08:39 +00:00
Adil Zouitine 0535f2a59a refactor(device_processor): Update device handling and improve type hints
- Changed device attribute type from torch.device to str for better clarity.
- Introduced a private _device attribute to store the actual torch.device instance.
- Updated tests to conditionally check for CUDA availability, ensuring compatibility across different environments.
- Refactored device-related assertions in tests to use a consistent approach for device type verification.
2025-08-06 18:08:15 +02:00
Michel Aractingi 2805ae347c fix(train.py): push postprocessor with preprocessor
- Add preprocessor policy overrides for device and rename_map
- Add rename_map to DatasetRecordConfig (record.py)
2025-08-06 17:21:17 +02:00
Adil Zouitine 28ef6fcd14 refactor(factory, pi0fast): Update processor function names and parameters
- Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency.
- Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions.
2025-08-06 17:21:16 +02:00
Adil Zouitine 7fc7ec75bb refactor(factory): Update processor configuration and type hints
- Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety.
- Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility.
- Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations.
- Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling.
2025-08-06 17:21:15 +02:00
Adil Zouitine 87890cbf38 refactor(processors): Standardize processor naming conventions
- Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format.
- Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme.
- Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability.
2025-08-06 17:21:14 +02:00
Adil Zouitine 5326ffe77e feature(pipeline): port tokenizer pipeline for VLA (#1645)
* feat(tokenizer): Introduce TokenizerProcessor for text tokenization

- Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer.
- Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings.
- Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor.
- Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor.

* feat(language): Enhance language processing in TokenizerProcessor

- Added OBS_LANGUAGE constant to define the observation language key.
- Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature.
- Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization.
- Modified tests to validate the integration of language tokens and attention masks in the observation structure.

* feat(tokenizer): Add padding configuration to TokenizerProcessor

- Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction.
- Updated the `make_pi0_processor` function to include the new padding configuration.
- Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios.

* feat(processor): Add state management methods to Pi0NewLineProcessor

* feat(normalization): Track normalization and unnormalization info in complementary data

- Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes.
- Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions.
- Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys.

* feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs

- Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations.
- Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization.

* feat(processors): Integrate RenameProcessor into various processor configurations

- Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency.
- Updated the input steps to ensure compatibility with the new RenameProcessor integration.

* feat(smolvla): Refactor language processing and introduce new line processor (#1658)

- Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant.
- Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility.
- Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling.

* feature(policies): add device processor (#1659)

* feat(processors): Integrate DeviceProcessor into multiple processor configurations

- Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines.
- Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Remove to() method for device management

- Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices.
- Removed associated unit tests that validated the functionality of the to() method across various scenarios.
- Streamlined the pipeline code by focusing on other device management strategies.

* feat(processor): Enhance DeviceProcessor with float dtype conversion

- Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types.
- Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype.
- Refactored tensor processing logic to streamline device movement and dtype conversion.
- Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios.

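The float-dtype conversion described above can be sketched roughly as follows. This is an illustrative snippet only, not the actual DeviceProcessor code (which also handles device movement); the helper name is made up:

```python
import torch


def convert_float_dtype(batch: dict, float_dtype: torch.dtype = torch.float32) -> dict:
    # Cast floating-point tensors to the requested dtype; preserve
    # integer/bool tensors and non-tensor values unchanged.
    out = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            out[key] = value.to(dtype=float_dtype)
        else:
            out[key] = value
    return out


batch = {"observation.state": torch.zeros(4, dtype=torch.float64), "index": torch.tensor([1])}
converted = convert_float_dtype(batch)
```

After conversion, `observation.state` is float32 while `index` keeps its integer dtype.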
* feat(policies): Add new line processors and update module exports

* feat(processor): Enhance batch and device processors to handle index and task_index fields

- Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors.
- Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged.
- Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation.
2025-08-06 17:21:13 +02:00
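The index-field batching in the last bullets above can be sketched as follows (an illustrative snippet, not the actual ToBatchProcessor implementation; the function name is hypothetical):

```python
import torch


def batch_index_fields(complementary_data: dict) -> dict:
    """Give 0-D index/task_index tensors a batch dimension; leave other fields untouched."""
    out = dict(complementary_data)
    for key in ("index", "task_index"):
        value = out.get(key)
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            out[key] = value.unsqueeze(0)  # scalar tensor -> shape (1,)
    return out


data = {"index": torch.tensor(7), "task_index": torch.tensor(2), "task": "pick cube"}
batched = batch_index_fields(data)
print(batched["index"].shape)  # torch.Size([1])
```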
pre-commit-ci[bot] a1734cf575 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-06 17:21:12 +02:00
Adil Zouitine 82f300e880 fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 2025-08-06 17:21:11 +02:00
Adil Zouitine 3e7c9d7afc feat(batch_processor): Add feature_contract method to ToBatchProcessor
- Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor.
- This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs.
2025-08-06 17:21:09 +02:00
Adil Zouitine e9cb779eab refactor(normalization): Clean up imports in normalize_processor.py 2025-08-06 17:21:08 +02:00
Adil Zouitine 8ff95be04c refactor(normalization): Remove unused state dict transformation methods and streamline imports
- Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process.
- Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging.
- Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation.
- Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.
2025-08-06 17:21:07 +02:00
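The hotswap_stats idea above can be sketched as below. This is a hypothetical minimal version; the real function operates on the normalization steps of a processor pipeline:

```python
class NormStep:
    """Stand-in for a normalization step that carries dataset statistics."""

    def __init__(self, stats: dict):
        self.stats = stats


def hotswap_stats(steps: list, new_stats: dict) -> list:
    # Swap normalization statistics on every step that exposes a `stats` attribute.
    for step in steps:
        if hasattr(step, "stats"):
            step.stats = new_stats
    return steps


pipeline = [NormStep({"mean": 0.0, "std": 1.0})]
hotswap_stats(pipeline, {"mean": 2.5, "std": 0.5})
```

This lets a trained policy be re-pointed at a new dataset's statistics without rebuilding the pipeline.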
Adil Zouitine f02ce69df0 refactor(diffusion): remove normalization layer from input processing 2025-08-06 17:21:07 +02:00
Adil Zouitine 1feb7b5d88 fix(rebase): remove residual normalization layer 2025-08-06 17:21:06 +02:00
Adil Zouitine fbe9009db2 feat(normalization): Implement IDENTITY mode for normalization and unnormalization
- Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified.
- Updated processing logic to check normalization modes and handle missing statistics gracefully.
- Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios.
- Improved error handling for unsupported normalization modes.
2025-08-06 17:21:05 +02:00
Adil Zouitine c0013b130b feat(batch_processor): Add task field processing to ToBatchProcessor
- Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference.
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings.
- Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.
2025-08-06 17:21:04 +02:00
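The task handling described above amounts to roughly this (illustrative only; the function name is hypothetical):

```python
def batch_task(task):
    # A bare string task gains a batch dimension; a list of strings passes through unchanged.
    if isinstance(task, str):
        return [task]
    if isinstance(task, list) and all(isinstance(t, str) for t in task):
        return task
    raise TypeError("task must be a string or a list of strings")


print(batch_task("pick up the cube"))  # ['pick up the cube']
```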
Adil Zouitine c4763f61a1 refactor(record): Rename processor parameters and update processing logic
- Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity.
- Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow.
- Ensured compatibility with existing functionality while improving code readability.
2025-08-06 17:21:03 +02:00
Adil Zouitine b95c219d96 feat(migrate): Extend load_model_from_hub to include train configuration
- Updated load_model_from_hub to return the train configuration alongside the model state_dict and config.
- Modified main function to handle the additional train configuration when loading models from both the hub and local paths.
- Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy.
2025-08-06 17:21:02 +02:00
Adil Zouitine 9b1138171e refactor(factory): Clean up imports in factory.py
- Removed unused import of IdentityProcessor to streamline the code.
2025-08-06 17:21:02 +02:00
Adil Zouitine 023b8f3466 feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration
- Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration.
- Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation.
- Refactored the loading of pretrained processors to utilize the new configuration options.
2025-08-06 17:21:00 +02:00
pre-commit-ci[bot] 1cad87ebd2 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-06 17:21:00 +02:00
Adil Zouitine 99de7567e6 feat(batch_processor): Enhance ToBatchProcessor to handle action batching
- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
2025-08-06 17:20:58 +02:00
Adil Zouitine 21baa8fa02 refactor(factory): Remove unused imports and NaN detection hook from processor creation 2025-08-06 17:20:53 +02:00
Adil Zouitine 8b4a5368b3 feat(processors): Add ToBatchProcessor to multiple policy processors
- Integrated ToBatchProcessor into various policy processors to handle observation batching.
- Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality.
- Ensured consistency across all processor implementations for improved data handling.
2025-08-06 17:20:52 +02:00
Adil Zouitine f5c6b03b61 feat(processor): Introduce ToBatchProcessor for handling observation batching
- Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing.
- Implemented functionality to add batch dimensions to state and image observations as needed.
- Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types.
- Ensured compatibility with existing transition keys and maintained the integrity of non-observation data.
2025-08-06 17:20:51 +02:00
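A rough sketch of that batching step is shown below. It is illustrative only: key names follow LeRobot's `observation.state` / `observation.images.*` convention, and the actual processor covers more cases:

```python
import torch


def add_batch_dim(observation: dict) -> dict:
    # 1-D state vectors become (1, D); 3-D images (C, H, W) become (1, C, H, W).
    out = {}
    for key, value in observation.items():
        if isinstance(value, torch.Tensor):
            if "image" in key and value.dim() == 3:
                value = value.unsqueeze(0)
            elif "state" in key and value.dim() == 1:
                value = value.unsqueeze(0)
        out[key] = value
    return out


obs = {"observation.state": torch.zeros(6), "observation.images.top": torch.zeros(3, 96, 96)}
batched = add_batch_dim(obs)
```

Already-batched tensors fall through unchanged, so the step is safe to apply idempotently.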
Adil Zouitine e7be2fd113 feat(migrate): Add model card generation and saving to migration script
- Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags.
- Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility.
- Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub.
2025-08-06 17:20:50 +02:00
Adil Zouitine b632490b4b feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models
- Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
- Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
- Improved error handling for missing mandatory configuration fields during model instantiation.
2025-08-06 17:20:50 +02:00
pre-commit-ci[bot] 9a9c7208d2 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-06 17:20:49 +02:00
pre-commit-ci[bot] 427b97d198 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-06 17:20:48 +02:00
AdilZouitine 2c2bb1e8bf feat(migration): Add script for migrating policy models with normalization layers 2025-08-06 17:20:47 +02:00
AdilZouitine 4b24f94225 feat(record): Integrate RobotProcessor into recording loop and update policy handling
- Added support for RobotProcessor in the record_loop function to enhance data processing capabilities.
- Updated the logic to reset both policy and processor when provided, ensuring proper state management.
- Modified action prediction to utilize the processor, improving the overall functionality of the recording process.
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.
2025-08-06 17:20:46 +02:00
AdilZouitine 670a278cbc refactor(policies): Enhance processor creation and add NaN detection hook 2025-08-06 17:20:45 +02:00
AdilZouitine fc74001202 refactor(train): Update preprocessor initialization to include dataset statistics 2025-08-06 17:20:45 +02:00
Adil Zouitine f14ac5d486 feat(train): Integrate preprocessor into training pipeline 2025-08-06 17:20:44 +02:00
Adil Zouitine 7bd0d62ce5 refactor(policies): Remove unnormalization step from action predictions
- Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction.
- This change improves code clarity and aligns with recent updates to policy processors.
2025-08-06 17:20:43 +02:00
Adil Zouitine 7eccefe235 refactor(learner): Remove normalization from cached image features retrieval
- Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls.
- This change enhances clarity and aligns with the recent updates to policy processors.
2025-08-06 17:20:42 +02:00
Adil Zouitine b72274066e feat(processors): Introduce processors for various policy types
- Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`.
- Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps.
- Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity.
- Enhanced test coverage to validate the integration of new processors with existing policy configurations.
2025-08-06 17:20:41 +02:00
Steven Palma 20f2910b63 Merge branch 'main' into user/azouitine/2025-7-2-implement-pipeline 2025-08-06 17:20:39 +02:00
Steven Palma fd4ae3466b refactor(pipeline): minor improvements (#1684)
* chore(pipeline): remove unused features + device torch + envtransition keys

* refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationProcessor

* refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code

* test(pipeline): fix broken test after refactors

* docs(pipeline): update docstrings VanillaObservationProcessor

* chore(pipeline): move None check to base pipeline classes
2025-08-06 14:00:13 +02:00
Adil Zouitine 7beb040e8e refactor(pipeline): Rename parameters for clarity and enhance save/load functionality
- Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path.
- Enhanced the save_pretrained method to ensure directory creation and file handling is consistent with the new parameter names.
- Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability.
2025-08-05 17:44:21 +02:00
Adil Zouitine 05bd18f453 refactor(observation): Streamline observation preprocessing and remove unused processor methods
- Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting.
- Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow.
- Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script.
2025-08-05 10:32:56 +02:00
Adil Zouitine 8077456c00 refactor(pipeline): Remove model card generation and streamline processor methods
- Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template.
- Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters.
- Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability.
2025-08-05 10:31:09 +02:00
AdilZouitine 5595887fd0 refactor(pipeline): Remove to() method for device management
- Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices.
- Removed associated unit tests that validated the functionality of the to() method across various scenarios.
- Streamlined the pipeline code by focusing on other device management strategies.
2025-08-05 10:27:25 +02:00
Adil Zouitine 41959389b6 docs(pipeline): Clarify transition handling and hook behavior
- Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats.
- Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change.
- Enhanced test assertions to verify the structure of results and the correctness of processing steps.
2025-08-02 14:51:52 +02:00
Pepijn 2c4e888c7f Feat/pipeline add feature contract (#1637)
* Add feature contract to pipelinestep and pipeline

* Add tests

* Add processor tests

* PR feedback

* incorporate PR feedback

* typo in doc

* oops
2025-08-01 08:41:54 +02:00
Adil Zouitine 5ced72e6b8 docs(pipeline): Add clarification for repo name sanitization process 2025-08-01 08:41:54 +02:00
Adil Zouitine 907023f9f7 refactor(pipeline): Improve state file naming conventions for clarity and uniqueness
- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
2025-08-01 08:41:54 +02:00
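As a sketch, such a naming scheme could look like this (hypothetical; the exact filename format used by LeRobot may differ):

```python
import re


def state_filename(processor_name: str, step_index: int, step_class: str) -> str:
    # Sanitize the processor name, then combine it with the step index and
    # step class so multiple processors saved in one directory never collide.
    safe = re.sub(r"[^a-zA-Z0-9_-]+", "_", processor_name).strip("_").lower()
    return f"{safe}_step_{step_index}_{step_class.lower()}.safetensors"


print(state_filename("My Processor!", 0, "Normalizer"))
# my_processor_step_0_normalizer.safetensors
```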
Adil Zouitine 4ba23ea029 feat(pipeline): Enhance configuration filename handling and state file naming
- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
- Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved.
- Added automatic detection for configuration files when loading from a directory, with error handling for multiple files.
- Updated tests to validate new features, including custom filenames and automatic config detection.
2025-08-01 08:41:54 +02:00
Adil Zouitine 409ac0baca chore(doc): address pip install command for lerobot that does not exist yet 2025-08-01 08:41:54 +02:00
Adil Zouitine 699363f9fc refactor(pipeline): Enhance state filename generation and profiling method
- Updated state filename generation to use the registry name when available, improving clarity in saved files.
- Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling.
- Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results.
2025-08-01 08:41:54 +02:00
Adil Zouitine ae7a54de57 refactor(pipeline): Utilize get_safe_torch_device for device assignment
- Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling.
- This change enhances code readability and maintains consistency in device management across the RobotProcessor class.
2025-08-01 08:41:54 +02:00
Adil Zouitine fb9139b882 chore(pipeline): Move _CFG_NAME along other class member 2025-08-01 08:41:54 +02:00
Adil Zouitine 9fe3a3fb17 feat(pipeline): Add __repr__ method to RobotProcessor for improved readability
- Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed.
- Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values.
- Ensured that the representation handles long lists of steps with truncation for better readability.
2025-08-01 08:41:54 +02:00
Adil Zouitine 26cb9a24c3 refactor(pipeline): Clarify hook behavior and improve documentation
- Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
- Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions.
- Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
- Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method.
2025-08-01 08:41:54 +02:00
Adil Zouitine 77106697c3 feat(pipeline): Add hook unregistration functionality and enhance documentation
- Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management.
- Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks.
- Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.
2025-08-01 08:41:54 +02:00
Adil Zouitine 75bc44c166 refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling
- Introduced constants for observation keys to enhance readability.
- Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly.
- Updated the environment state and agent position assignments to use the new constants, improving maintainability.
2025-08-01 08:41:54 +02:00
Adil Zouitine f2b79656eb refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
2025-08-01 08:41:53 +02:00
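The tuple-to-dictionary change described above can be sketched as follows. The enum members here are hypothetical stand-ins; the real `TransitionKey` lives in the lerobot processor module:

```python
from enum import Enum


class TransitionKey(str, Enum):
    # Hypothetical members mirroring the commit description.
    OBSERVATION = "observation"
    ACTION = "action"
    REWARD = "reward"
    DONE = "done"


transition = {
    TransitionKey.OBSERVATION: {"state": [0.0, 1.0]},
    TransitionKey.ACTION: [0.5],
    TransitionKey.REWARD: 0.0,
    TransitionKey.DONE: False,
}
# Named-key access replaces positional indexing (the old TransitionIndex tuple style).
reward = transition[TransitionKey.REWARD]
```

Dictionary access by named key is self-documenting and tolerant of reordering, which is why it reads better than `transition[2]`-style indexing.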
pre-commit-ci[bot] 14c2ece004 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-01 08:41:53 +02:00
Adil Zouitine 35612c61e1 refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions 2025-08-01 08:41:53 +02:00
pre-commit-ci[bot] f7bb3e2d90 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-01 08:41:53 +02:00
Adil Zouitine 1e0d667a22 Apply suggestions from code review
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-08-01 08:41:53 +02:00
Adil Zouitine 33969a0337 refactor(pipeline): Simplify observation and padding data handling in batch transitions 2025-08-01 08:41:53 +02:00
Adil Zouitine fa26290e8c feat(pipeline): Enhance step_through method to support both tuple and dict inputs 2025-08-01 08:41:53 +02:00
Adil Zouitine e9f7f5127b chore(learner): nit comment from copilot 2025-08-01 08:41:53 +02:00
Adil Zouitine 097842c70f chore(normalization): addressing comments from copilot 2025-08-01 08:41:53 +02:00
Adil Zouitine 3b8a3a32a0 feat (overrides): Implement support for loading processors with parameter overrides
- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
2025-08-01 08:41:53 +02:00
Adil Zouitine 1c56779dd9 chore (type): add typing for multiprocess envs 2025-08-01 08:41:53 +02:00
Adil Zouitine 83a4338f8b chore (output format): improves output format 2025-08-01 08:41:53 +02:00
Adil Zouitine 730c7b2f35 fix(test): linting issue 2025-08-01 08:41:53 +02:00
pre-commit-ci[bot] 116059a43e [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-01 08:41:53 +02:00
Adil Zouitine b08149a113 chore (batch handling): Enhance processing components with batch conversion utilities 2025-08-01 08:41:53 +02:00
Adil Zouitine c227107f60 feat (device processor): Implement device processor 2025-08-01 08:41:53 +02:00
Adil Zouitine 01dc289f3d chore(docstring): improve docstring for NormalizerProcessor 2025-08-01 08:41:53 +02:00
Adil Zouitine 6830ca7645 Refactor normalization components and update tests
- Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity.
- Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`.
- Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes.
- Enhanced handling of missing statistics in normalization processes.
2025-08-01 08:41:52 +02:00
Adil Zouitine ed42c71fc3 fix(test): import issue 2025-08-01 08:41:52 +02:00
Adil Zouitine e0139065bd chore(test): add suggestion made by copilot regarding numpy test 2025-08-01 08:41:52 +02:00
Adil Zouitine e509f255af Update tests/processor/test_observation_processor.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-08-01 08:41:52 +02:00
Adil Zouitine e2fcd140b0 fix(test): policies 2025-08-01 08:41:52 +02:00
Adil Zouitine 2a7a0e6129 fix (test): test factory 2025-08-01 08:41:52 +02:00
Adil Zouitine 9f33791b19 chore (docs): add docstring for processor 2025-08-01 08:41:52 +02:00
Adil Zouitine 453e0a995f Enhance processing architecture with new components
- Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility.
- Updated `__init__.py` to include `RenameProcessor` in module exports.
- Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling.
- Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.
2025-08-01 08:41:52 +02:00
pre-commit-ci[bot] 8ebf79c494 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-01 08:41:52 +02:00
Adil Zouitine 8774aec304 Add normalization processor and related components
- Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization.
- Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks.
- Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports.
- Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity.
- Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing.
2025-08-01 08:41:52 +02:00
pre-commit-ci[bot] ac742c9f0d [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-01 08:41:52 +02:00
Adil Zouitine cd13f1ecfd Add RobotProcessor tutorial to documentation
- Introduced a new tutorial on using RobotProcessor for preprocessing robot data.
- Added a section in the table of contents for easy navigation to the new tutorial.
- The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline.
2025-08-01 08:41:52 +02:00
Adil Zouitine 9aa632968f Refactor processing architecture to use RobotProcessor
- Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity.
- Introduced ProcessorStepRegistry for better management of processing steps.
- Updated relevant documentation and tests to reflect the new processing structure.
- Enhanced the save/load functionality to support the new processor design.
- Added a model card template for RobotProcessor to facilitate sharing and documentation.
2025-08-01 08:41:52 +02:00
Adil Zouitine 62caaf07b0 Remove redundant tests for None observation and serialization methods in test_observation_processor.py to streamline the test suite and improve maintainability. 2025-08-01 08:41:52 +02:00
Adil Zouitine 3355f04ca6 Refactor observation processing and improve modularity
- Updated `ObservationProcessor` to enhance the modular design for processing observations.
- Cleaned up imports and improved code readability by removing unnecessary lines and comments.
- Ensured backward compatibility while integrating new processing components.
- Added tests to validate the functionality of the updated processing architecture.
2025-08-01 08:41:52 +02:00
pre-commit-ci[bot] 769f531603 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-08-01 08:41:51 +02:00
Adil Zouitine f6c7287ae7 Refactor observation preprocessing to use a modular pipeline system
- Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations.
- Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline.
- Added tests for the new processing components and ensured they match the original functionality.
- Removed hardcoded logic in favor of a more flexible, composable architecture.
2025-08-01 08:41:51 +02:00
117 changed files with 15847 additions and 4735 deletions
@@ -4,7 +4,13 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient
HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process.
It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
It combines three key ingredients:
1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
@@ -56,30 +62,243 @@ pip install -e ".[hilserl]"
### Understanding Configuration
The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as:
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
<!-- prettier-ignore-start -->
```python
class GymManipulatorConfig:
env: HILSerlRobotEnvConfig # Environment configuration (nested)
dataset: DatasetConfig # Dataset recording/replay configuration (nested)
mode: str | None = None # "record", "replay", or None (for training)
device: str = "cpu" # Compute device
class HILSerlRobotEnvConfig(EnvConfig):
robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`)
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`)
wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py`
fps: int = 10 # Control frequency
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm
processor: HILSerlProcessorConfig # Processing pipeline configuration (nested)
name: str = "real_robot" # Environment name
mode: str = None # "record", "replay", or None (for training)
repo_id: str | None = None # LeRobot dataset repository ID
dataset_root: str | None = None # Local dataset root (optional)
task: str = "" # Task identifier
num_episodes: int = 10 # Number of episodes for recording
episode: int = 0 # episode index for replay
device: str = "cuda" # Compute device
push_to_hub: bool = True # Whether to push the recorded datasets to Hub
pretrained_policy_name_or_path: str | None = None # For policy loading
reward_classifier_pretrained_path: str | None = None # For reward model
number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier
task: str | None = None # Task identifier
fps: int = 10 # Control frequency
# Nested processor configuration
class HILSerlProcessorConfig:
control_mode: str = "gamepad" # Control mode
observation: ObservationConfig | None = None # Observation processing settings
image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings
gripper: GripperConfig | None = None # Gripper control and penalty settings
reset: ResetConfig | None = None # Environment reset and timing settings
inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings
reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings
max_gripper_pos: float | None = 100.0 # Maximum gripper position
# Sub-configuration classes
class ObservationConfig:
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
add_current_to_observation: bool = False # Add motor currents to state
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
display_cameras: bool = False # Display camera feeds during execution
class ImagePreprocessingConfig:
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters
resize_size: tuple[int, int] | None = None # Target image size
class GripperConfig:
use_gripper: bool = True # Enable gripper control
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
class ResetConfig:
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
reset_time_s: float = 5.0 # Time to wait during reset
control_time_s: float = 20.0 # Maximum episode duration
terminate_on_success: bool = True # Whether to terminate episodes on success detection
class InverseKinematicsConfig:
urdf_path: str | None = None # Path to robot URDF file
target_frame_name: str | None = None # End-effector frame name
end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds
end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis
class RewardClassifierConfig:
pretrained_path: str | None = None # Path to pretrained reward classifier
success_threshold: float = 0.5 # Success detection threshold
success_reward: float = 1.0 # Reward value for successful episodes
# Dataset configuration
class DatasetConfig:
repo_id: str # LeRobot dataset repository ID
task: str # Task identifier
root: str | None = None # Local dataset root directory
num_episodes_to_record: int = 5 # Number of episodes for recording
replay_episode: int | None = None # Episode index for replay
push_to_hub: bool = False # Whether to push datasets to Hub
```
<!-- prettier-ignore-end -->
### Processor Pipeline Architecture
HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components:
#### Environment Processor Pipeline
The environment processor (`env_processor`) handles incoming observations and environment state:
1. **VanillaObservationProcessorStep**: Converts raw robot observations into standardized format
2. **JointVelocityProcessorStep** (optional): Adds joint velocity information to observations
3. **MotorCurrentProcessorStep** (optional): Adds motor current readings to observations
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
5. **ImageCropResizeProcessorStep** (optional): Crops and resizes camera images
6. **TimeLimitProcessorStep** (optional): Enforces episode time limits
7. **GripperPenaltyProcessorStep** (optional): Applies penalties for inappropriate gripper usage
8. **RewardClassifierProcessorStep** (optional): Automated reward detection using vision models
9. **AddBatchDimensionProcessorStep**: Converts data to batch format for neural network processing
10. **DeviceProcessorStep**: Moves data to the specified compute device (CPU/GPU)
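Conceptually, the pipeline is just an ordered chain of observation transforms. A minimal sketch, with plain functions standing in for the real step classes (which in LeRobot are configurable objects):

```python
from typing import Any, Callable

Observation = dict[str, Any]
Step = Callable[[Observation], Observation]


class ObservationPipeline:
    """Applies processing steps in order, mirroring the env_processor flow."""

    def __init__(self, steps: list[Step]):
        self.steps = steps

    def __call__(self, obs: Observation) -> Observation:
        for step in self.steps:
            obs = step(obs)
        return obs


def add_batch_dim(obs: Observation) -> Observation:
    # Mimics AddBatchDimensionProcessorStep: wrap list values in a batch axis
    return {k: [v] if isinstance(v, list) else v for k, v in obs.items()}


def tag_device(obs: Observation) -> Observation:
    # Stand-in for DeviceProcessorStep: record where tensors should live
    obs["device"] = "cpu"
    return obs


pipeline = ObservationPipeline([add_batch_dim, tag_device])
out = pipeline({"observation.state": [0.1, 0.2]})
```

Each optional step in the list above slots into this chain in the order shown, controlled by the corresponding config flags.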
#### Action Processor Pipeline
The action processor (`action_processor`) handles outgoing actions and human interventions:
1. **AddTeleopActionAsComplimentaryDataStep**: Captures teleoperator actions for logging
2. **AddTeleopEventsAsInfoStep**: Records intervention events and episode control signals
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
4. **InterventionActionProcessorStep**: Handles human interventions and episode termination
5. **Inverse Kinematics Pipeline** (when enabled):
- **MapDeltaActionToRobotActionStep**: Converts delta actions to robot action format
- **EEReferenceAndDelta**: Computes end-effector reference and delta movements
- **EEBoundsAndSafety**: Enforces workspace safety bounds
- **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
- **GripperVelocityToJoint**: Handles gripper control commands
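The intervention handling at the heart of this pipeline reduces to a simple override rule. A sketch (simplified, not the actual `InterventionActionProcessorStep` code):

```python
def select_executed_action(policy_action, teleop_action, is_intervention):
    """During a human intervention the teleoperator's action overrides the
    policy's; the overriding action is what gets executed and logged for
    learning (conceptual sketch)."""
    return list(teleop_action) if is_intervention else list(policy_action)
```

Logging the executed action (rather than the policy's proposal) is what lets HIL-SERL learn from human corrections.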
#### Configuration Examples
**Basic Observation Processing**:
```json
{
"env": {
"processor": {
"observation": {
"add_joint_velocity_to_observation": true,
"add_current_to_observation": false,
"display_cameras": false
}
}
}
}
```
**Image Processing**:
```json
{
"env": {
"processor": {
"image_preprocessing": {
"crop_params_dict": {
"observation.images.front": [180, 250, 120, 150],
"observation.images.side": [180, 207, 180, 200]
},
"resize_size": [128, 128]
}
}
}
}
```
**Inverse Kinematics Setup**:
```json
{
"env": {
"processor": {
"inverse_kinematics": {
"urdf_path": "path/to/robot.urdf",
"target_frame_name": "end_effector",
"end_effector_bounds": {
"min": [0.16, -0.08, 0.03],
"max": [0.24, 0.2, 0.1]
},
"end_effector_step_sizes": {
"x": 0.02,
"y": 0.02,
"z": 0.02
}
}
}
}
}
```
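Under these settings, each normalized action component is scaled by its per-axis step size and the resulting end-effector target is clipped to the workspace bounds. A conceptual sketch of what the `EEReferenceAndDelta` and `EEBoundsAndSafety` steps do (the real math lives in the kinematic processors):

```python
def step_ee_target(current_xyz, delta_cmd, step_sizes, bounds):
    """Scale a normalized [-1, 1] delta command per axis and clip the new
    end-effector target to the configured workspace bounds (sketch)."""
    target = []
    for i, axis in enumerate(("x", "y", "z")):
        raw = current_xyz[i] + step_sizes[axis] * delta_cmd[i]
        target.append(min(max(raw, bounds["min"][i]), bounds["max"][i]))
    return target
```

With the config values above, a full-scale command moves the target by at most 2 cm per step, and the z axis can never drop below 0.03 m.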
### Advanced Observation Processing
The HIL-SERL framework supports additional observation processing features that can improve policy learning:
#### Joint Velocity Processing
Enable joint velocity estimation to provide the policy with motion information:
```json
{
"env": {
"processor": {
"observation": {
"add_joint_velocity_to_observation": true
}
}
}
}
```
This processor:
- Estimates joint velocities using finite differences between consecutive joint position readings
- Adds velocity information to the observation state vector
- Useful for policies that need motion awareness for dynamic tasks
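The finite-difference estimate is straightforward; a sketch of the idea (not the exact LeRobot implementation):

```python
def estimate_joint_velocity(curr_pos, prev_pos, dt):
    """Finite-difference velocity between two consecutive joint readings
    taken dt seconds apart (illustrative sketch)."""
    return [(c - p) / dt for c, p in zip(curr_pos, prev_pos)]
```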
#### Motor Current Processing
Monitor motor currents to detect contact forces and load conditions:
```json
{
"env": {
"processor": {
"observation": {
"add_current_to_observation": true
}
}
}
}
```
This processor:
- Reads motor current values from the robot's control system
- Adds current measurements to the observation state vector
- Helps detect contact events, object weights, and mechanical resistance
- Useful for contact-rich manipulation tasks
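The processor itself only appends the raw currents to the state vector; what a policy can pick up from them is illustrated by a simple, hypothetical contact heuristic (threshold and units are made up for the example):

```python
def likely_contact(currents_ma, baseline_ma, threshold_ma=150.0):
    """Flag probable contact when any motor draws markedly more current than
    its free-motion baseline (hypothetical heuristic, not a LeRobot API)."""
    return any(abs(c - b) > threshold_ma for c, b in zip(currents_ma, baseline_ma))
```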
#### Combined Observation Processing
You can enable multiple observation processing features simultaneously:
```json
{
"env": {
"processor": {
"observation": {
"add_joint_velocity_to_observation": true,
"add_current_to_observation": true,
"add_ee_pose_to_observation": false,
"display_cameras": false
}
}
}
}
```
**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data.
### Finding Robot Workspace Bounds
Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.
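One practical way to obtain such bounds, sketched here as a hypothetical helper: teleoperate the robot through the task space, log end-effector positions, and take the per-axis extrema (optionally padded with a small margin):

```python
def bounds_from_logged_positions(positions):
    """Derive end_effector_bounds from a list of [x, y, z] positions logged
    while jogging the robot around the task space (sketch)."""
    mins = [min(p[i] for p in positions) for i in range(3)]
    maxs = [max(p[i] for p in positions) for i in range(3)]
    return {"min": mins, "max": maxs}
```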
With the bounds defined, you can safely collect demonstrations for training.
Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)):
1. Set `mode` to `"record"` at the root level
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
Example configuration section:
```json
"mode": "record",
"repo_id": "username/pick_lift_cube",
"dataset_root": null,
"task": "pick_and_lift",
"num_episodes": 15,
"episode": 0,
"push_to_hub": true
{
"env": {
"type": "gym_manipulator",
"name": "real_robot",
"fps": 10,
"processor": {
"control_mode": "gamepad",
"observation": {
"display_cameras": false
},
"image_preprocessing": {
"crop_params_dict": {},
"resize_size": [128, 128]
},
"gripper": {
"use_gripper": true,
"gripper_penalty": 0.0
},
"reset": {
"reset_time_s": 5.0,
"control_time_s": 20.0
}
},
"robot": {
// ... robot configuration ...
},
"teleop": {
// ... teleoperator configuration ...
}
},
"dataset": {
"repo_id": "username/pick_lift_cube",
"root": null,
"task": "pick_and_lift",
"num_episodes_to_record": 15,
"replay_episode": 0,
"push_to_hub": true
},
"mode": "record",
"device": "cpu"
}
```
### Using a Teleoperation Device
The gamepad provides a very convenient way to control the robot and the episode.
To set up the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file.
```json
{
"env": {
"teleop": {
"type": "gamepad",
"use_gripper": true
},
"processor": {
"control_mode": "gamepad",
"gripper": {
"use_gripper": true
}
}
}
}
```
<p align="center">
The SO101 leader arm has reduced gears that allow it to move and track the follower arm.
To set up the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file.
```json
{
"env": {
"teleop": {
"type": "so101_leader",
"port": "/dev/tty.usbmodem585A0077921",
"use_degrees": true
},
"processor": {
"control_mode": "leader",
"gripper": {
"use_gripper": true
}
}
}
}
```
In order to annotate the success/failure of the episode, **you will need** a keyboard: press `s` for success or `esc` for failure.
python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e
During recording:
1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions`
2. Complete the task successfully
3. The episode ends with a reward of 1 when you press the "success" button
4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
observation.images.front: [180, 250, 120, 150]
Add these crop parameters to your training configuration:
```json
{
"env": {
"processor": {
"image_preprocessing": {
"crop_params_dict": {
"observation.images.side": [180, 207, 180, 200],
"observation.images.front": [180, 250, 120, 150]
},
"resize_size": [128, 128]
}
}
}
}
```
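Applying such a crop is simple; the sketch below assumes the four crop parameters are ordered (top, left, height, width), which you should verify against your LeRobot version before relying on it:

```python
def crop_image(image, top, left, height, width):
    """Crop a row-major image (list of rows), assuming (top, left, height,
    width) parameter order (sketch; verify the ordering for your version)."""
    return [row[left:left + width] for row in image[top:top + height]]


# A 4x6 toy "image" whose pixel value encodes its (row, col) position
img = [[r * 10 + c for c in range(6)] for r in range(4)]
patch = crop_image(img, 1, 2, 2, 3)
```

Cropping before resizing keeps the task-relevant region at full detail instead of shrinking the whole frame.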
**Recommended image resolution**
python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r
**Key Parameters for Data Collection**
- **mode**: set it to `"record"` to collect a dataset (at root level)
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
- **dataset.num_episodes_to_record**: Number of episodes to record
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
- **env.fps**: Number of frames per second to record
- **dataset.push_to_hub**: Whether to push the dataset to the hub
The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection.
**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully.
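The termination behavior described above can be sketched as follows (illustrative, not the actual processor code):

```python
def episode_done(reward, steps, max_steps, terminate_on_success=True):
    """Success ends the episode only when terminate_on_success is set;
    the time limit always does (conceptual sketch)."""
    success = reward >= 1.0
    done = (success and terminate_on_success) or steps >= max_steps
    return done, success
```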
Example configuration section for data collection:
```json
{
"env": {
"type": "gym_manipulator",
"name": "real_robot",
"fps": 10,
"processor": {
"reset": {
"reset_time_s": 5.0,
"control_time_s": 20.0,
"terminate_on_success": false
},
"gripper": {
"use_gripper": true
}
},
"robot": {
// ... robot configuration ...
},
"teleop": {
// ... teleoperator configuration ...
}
},
"dataset": {
"repo_id": "hf_username/dataset_name",
"root": "data/your_dataset",
"task": "reward_classifier_task",
"num_episodes_to_record": 20,
"replay_episode": null,
"push_to_hub": true
},
"mode": "record",
"device": "cpu"
}
```
To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to use it:
<!-- prettier-ignore-start -->
```python
config = GymManipulatorConfig(
env=HILSerlRobotEnvConfig(
processor=HILSerlProcessorConfig(
reward_classifier=RewardClassifierConfig(
pretrained_path="path_to_your_pretrained_trained_model"
)
),
# Other environment parameters
),
dataset=DatasetConfig(...),
mode=None # For training
)
```
<!-- prettier-ignore-end -->
Or set the argument in the JSON config file:
```json
{
"env": {
"processor": {
"reward_classifier": {
"pretrained_path": "path_to_your_pretrained_model",
"success_threshold": 0.7,
"success_reward": 1.0
},
"reset": {
"terminate_on_success": true
}
}
}
}
```
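The decision rule these settings control can be sketched as follows (conceptual, not the actual `RewardClassifierProcessorStep` code):

```python
def classifier_reward(success_prob, success_threshold=0.5, success_reward=1.0):
    """Map the classifier's success probability to a sparse reward using the
    success_threshold and success_reward settings (sketch)."""
    return success_reward if success_prob >= success_threshold else 0.0
```

Raising `success_threshold` trades missed successes for fewer false positives; with `terminate_on_success` enabled, a positive classification also ends the episode.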
To use `gym_hil` with LeRobot, you need to create a configuration file. An example:
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"device": "cuda"
}
```
Available tasks:
- `PandaPickCubeGamepad-v0`: With gamepad control
- `PandaPickCubeKeyboard-v0`: With keyboard control
### Processor Configuration
```json
{
"env": {
"processor": {
"control_mode": "gamepad",
"gripper": {
"use_gripper": true,
"gripper_penalty": -0.02
},
"reset": {
"control_time_s": 15.0,
"fixed_reset_joint_positions": [
0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785
]
},
"inverse_kinematics": {
"end_effector_step_sizes": {
"x": 0.025,
"y": 0.025,
"z": 0.025
}
}
}
}
}
```
Important parameters:
- `gripper.gripper_penalty`: Penalty for excessive gripper movement
- `gripper.use_gripper`: Whether to enable gripper control
- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
- `control_mode`: Set to `"gamepad"` to use a gamepad controller
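How a gripper penalty can shape the reward, as a sketch; the trigger condition used here (any change in gripper command) is a simplification of the real step's logic:

```python
def penalized_reward(task_reward, gripper_command_changed, gripper_penalty=-0.02):
    """Subtract a small penalty whenever the gripper command changes, to
    discourage excessive gripper movement (illustrative sketch)."""
    return task_reward + (gripper_penalty if gripper_command_changed else 0.0)
```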
## Running with HIL RL of LeRobot
To run the environment, set `mode` to `null`:
<!-- prettier-ignore-start -->
```bash
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```
<!-- prettier-ignore-end -->
### Recording a Dataset
To collect a dataset, set the mode to `record` while defining the `repo_id` and the number of episodes to record:
<!-- prettier-ignore-start -->
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0"
},
"dataset": {
"repo_id": "username/sim_dataset",
"root": null,
"task": "pick_cube",
"num_episodes_to_record": 10,
"replay_episode": null,
"push_to_hub": true
},
"mode": "record"
}
```
```bash
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```
<!-- prettier-ignore-end -->
### Training a Policy
To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers:
<!-- prettier-ignore-start -->
```bash
python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json
```
<!-- prettier-ignore-end -->
In a different terminal, run the learner server:
<!-- prettier-ignore-start -->
```bash
python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json
```
<!-- prettier-ignore-end -->
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.record import record_loop
from lerobot.policies.factory import make_processor
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
robot = SO100Follower(robot_config)
# Initialize the policy
policy = ACTPolicy.from_pretrained("<hf_username>/<my_policy_repo_id>")
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/eval_<dataset_repo_id>",
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
_init_rerun(session_name="recording")
# Connect the robot
robot.connect()
preprocessor, postprocessor = make_processor(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
pip install -e ".[hilserl]"
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json).
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"dataset": {
"repo_id": "your_username/il_gym",
"root": null,
"task": "pick_cube",
"num_episodes_to_record": 30,
"replay_episode": null,
"push_to_hub": true
},
"mode": "record",
"device": "cuda"
}
```
Key configuration points:
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
- Ensure `mode` is set to `"record"`
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
Then we can run this command to start:
huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
## Evaluate your policy in Sim
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
Here's an example evaluation configuration:
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"dataset": {
"repo_id": "your_username/il_sim_dataset",
"root": null,
"task": "pick_cube"
},
"pretrained_policy_name_or_path": "your_username/il_sim_model",
"device": "cuda"
}
```
Make sure to replace:
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
Then you can run this command to visualize your trained policy:
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.record import record_loop
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.utils.control_utils import init_keyboard_listener
NUM_EPISODES = 2
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
robot = LeKiwiClient(robot_config)
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
listener, events = init_keyboard_listener()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
while True:
keyboard_keys = keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
log_rerun_data(observation=observation, action={**arm_action, **base_action})
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
observation_to_transition,
transition_to_robot_action,
)
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
# Initialize the robot with degrees
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Initialize the robot
robot = SO100Follower(robot_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints_processor = RobotProcessorPipeline(
steps=[
AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=lambda tr: tr,
to_output=transition_to_robot_action,
)
# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose_processor = RobotProcessorPipeline(
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=observation_to_transition,
to_output=lambda tr: tr,
)
# Build dataset action and gripper features
action_ee_and_gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints_processor,
initial_features={},
use_videos=True,
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
) # Get all ee action features + gripper pos action features
# Build dataset observation features
obs_ee = aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose_processor,
initial_features=robot.observation_features,
use_videos=True,
patterns=["observation.state.ee"],
) # Get all ee observation features
dataset_features = combine_feature_dicts(obs_ee, action_ee_and_gripper)
print("All dataset features: ", dataset_features)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")
# Connect the robot and teleoperator
robot.connect()
episode_idx = 0
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
dataset.save_episode()
# Clean up
log_say("Stop recording")
robot.disconnect()
dataset.push_to_hub()
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
action_to_transition,
observation_to_transition,
transition_to_robot_action,
)
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
EEBoundsAndSafety,
EEReferenceAndDelta,
ForwardKinematicsJointsToEE,
GripperVelocityToJoint,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
NUM_EPISODES = 10
FPS = 30
EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 30
TASK_DESCRIPTION = "My task description"
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
# Initialize the robot and teleoperator
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose_processor = RobotProcessorPipeline(
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.20,
max_ee_twist_step_rad=0.50,
),
],
to_transition=action_to_transition,
to_output=lambda tr: tr,
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints_processor = RobotProcessorPipeline(
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
GripperVelocityToJoint(
motor_names=list(robot.bus.motors.keys()),
speed_factor=20.0,
),
],
to_transition=lambda tr: tr,
to_output=transition_to_robot_action,
)
# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessorPipeline(
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=observation_to_transition,
to_output=lambda tr: tr,
)
# Build dataset ee action features
action_ee = aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose_processor,
initial_features=phone.action_features,
use_videos=True,
patterns=["action.ee"],
)
# Get gripper pos action features
gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints_processor,
initial_features={},
use_videos=True,
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
)
# Build dataset ee observation features
observation_ee = aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose,
initial_features=robot.observation_features,
use_videos=True,
patterns=["observation.state.ee"],
)
dataset_features = combine_feature_dicts(action_ee, gripper, observation_ee)
print("All dataset features: ", dataset_features)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")
# Connect the robot and teleoperator
robot.connect()
phone.connect()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
episode_idx += 1
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
dataset.push_to_hub()
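The recording script above cycles through record, reset, and optional re-record phases driven by keyboard events. A minimal, self-contained sketch of that control flow (simplified: the real script also runs a reset phase before redoing a re-recorded episode, and clears the dataset's episode buffer):

```python
def run_episodes(num_episodes, record, reset, events):
    # Sketch of the record / reset / re-record loop above.
    # `events` mirrors the dict returned by init_keyboard_listener.
    ep = 0
    saved = []
    while ep < num_episodes and not events["stop_recording"]:
        data = record(ep)
        if events["rerecord_episode"]:
            # Discard this attempt and redo the same episode index.
            events["rerecord_episode"] = False
            continue
        saved.append(data)
        ep += 1
        if ep < num_episodes and not events["stop_recording"]:
            reset()
    return saved

events = {"stop_recording": False, "rerecord_episode": False}
calls = []

def record(ep):
    calls.append(ep)
    if len(calls) == 2:  # simulate asking to redo episode 1 once
        events["rerecord_episode"] = True
    return ep

saved = run_episodes(3, record, reset=lambda: None, events=events)
print(calls, saved)  # [0, 1, 1, 2] [0, 1, 2]
```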
+81
View File
@@ -0,0 +1,81 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
robot = SO100Follower(robot_config)
robot.connect()
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
actions = dataset.hf_dataset.select_columns("action")
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints_processor = RobotProcessorPipeline(
steps=[
AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=False, # Because replay is open loop
),
],
to_transition=action_to_transition,
to_output=transition_to_robot_action,
)
robot_ee_to_joints_processor.reset()
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(dataset.num_frames):
t0 = time.perf_counter()
ee_action = {
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
}
joint_action = robot_ee_to_joints_processor(ee_action)
action_sent = robot.send_action(joint_action)
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
robot.disconnect()
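The replay loop above paces playback at the dataset's fps by measuring each iteration with `time.perf_counter()` and waiting out the remainder of the frame budget. A self-contained sketch of that fixed-rate pattern, using a plain `time.sleep` as a stand-in for lerobot's `busy_wait`:

```python
import time

def run_at_fixed_rate(step_fn, fps: float, n_steps: int) -> None:
    # Each iteration: do the work, then wait out whatever is left
    # of the 1/fps budget (stdlib stand-in for busy_wait).
    period = 1.0 / fps
    for i in range(n_steps):
        t0 = time.perf_counter()
        step_fn(i)
        remaining = period - (time.perf_counter() - t0)
        if remaining > 0:
            time.sleep(remaining)

start = time.perf_counter()
run_at_fixed_rate(lambda i: None, fps=50, n_steps=5)
elapsed = time.perf_counter() - start
print(f"{elapsed:.2f}s")  # roughly 0.10s for 5 steps at 50 fps
```

`busy_wait` spins instead of sleeping for tighter timing on short intervals; the structure of the loop is the same.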
+93
View File
@@ -0,0 +1,93 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData,
EEBoundsAndSafety,
EEReferenceAndDelta,
GripperVelocityToJoint,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
# Initialize the robot and teleoperator
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop_device = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to ee pose action to joint action
phone_to_robot_joints_processor = RobotProcessorPipeline(
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
),
GripperVelocityToJoint(
motor_names=list(robot.bus.motors.keys()),
speed_factor=20.0,
),
],
to_transition=action_to_transition,
to_output=transition_to_robot_action,
)
robot.connect()
teleop_device.connect()
print("Starting teleop loop. Move your phone to teleoperate the robot.")
while True:
# Get teleop observation
phone_obs = teleop_device.get_action()
# Phone -> EE pose -> Joints transition
joint_action = phone_to_robot_joints_processor(phone_obs)
if joint_action:
robot.send_action(joint_action)
time.sleep(0.01)
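The `RobotProcessorPipeline` calls above all follow one pattern: convert the input to a transition, run the steps in order, and convert the result back out. A generic sketch of that call shape (hypothetical helper, not the library class):

```python
from functools import reduce

def make_pipeline(steps, to_transition, to_output):
    # input -> transition -> step1 -> ... -> stepN -> output,
    # mirroring the to_transition / steps / to_output arguments above.
    def run(data):
        tr = to_transition(data)
        tr = reduce(lambda acc, step: step(acc), steps, tr)
        return to_output(tr)
    return run

# Two toy steps standing in for e.g. EEReferenceAndDelta and
# InverseKinematicsEEToJoints:
double = lambda tr: {k: v * 2 for k, v in tr.items()}
offset = lambda tr: {k: v + 1 for k, v in tr.items()}

pipeline = make_pipeline([double, offset], to_transition=dict, to_output=dict)
print(pipeline({"x": 3}))  # {'x': 7}
```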
+4 -2
View File
@@ -95,7 +95,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
transformers-dep = ["transformers<=4.52.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors
@@ -111,6 +111,7 @@ intelrealsense = [
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
# stretch = [
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
@@ -152,7 +153,8 @@ all = [
"lerobot[video_benchmark]",
"lerobot[aloha]",
"lerobot[pusht]",
"lerobot[xarm]"
"lerobot[xarm]",
"lerobot[phone]",
]
[project.scripts]
+1 -2
View File
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
@@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
+1
View File
@@ -24,6 +24,7 @@ class FeatureType(str, Enum):
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"
LANGUAGE = "LANGUAGE"
class NormalizationMode(str, Enum):
+9
View File
@@ -21,8 +21,14 @@ OBS_ENV_STATE = "observation.environment_state"
OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
OBS_LANGUAGE = "observation.language"
ACTION = "action"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"
OBS_LANGUAGE_TOKENS = "observation.language.tokens"
OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask"
ROBOTS = "robots"
ROBOT_TYPE = "robot_type"
@@ -39,6 +45,9 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor"
POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor"
if "LEROBOT_HOME" in os.environ:
raise ValueError(
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
+95
View File
@@ -0,0 +1,95 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline
def aggregate_pipeline_dataset_features(
pipeline: DataProcessorPipeline,
initial_features: dict[str, Any],
*,
use_videos: bool = True,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
    Aggregate the pipeline's features into a features dict ready for the dataset,
    filtered to the keys that match any of the given patterns (applied to action/state features only).
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
- `use_videos`: whether to treat image features as video streams
- `patterns`: regexes to filter action & state features; images are included
whenever use_videos=True, regardless of patterns.
"""
import re
# Gather everything the pipeline features specifies, seeded with hardware cams:
all_features = pipeline.transform_features(initial_features)
# Helper to decide which action/state keys survive the `patterns` filter:
def keep(key: str) -> bool:
if patterns is None:
return True
return any(re.search(pat, key) for pat in patterns)
# Start with hardware dict, injecting initial cameras if videos are ON:
hw: dict[str, dict[str, Any]] = {}
if use_videos:
cams = {
name: shape
for name, shape in initial_features.items()
if isinstance(shape, tuple) and len(shape) == 3
}
if cams:
hw["observation"] = dict(cams)
# Go over every feature from the pipeline and merge:
for full_key, ty in all_features.items():
if full_key.startswith(f"{ACTION}."):
# action.<feat>
if not keep(full_key):
continue
name = full_key[len(f"{ACTION}.") :]
hw.setdefault(ACTION, {})[name] = ty
elif full_key.startswith(f"{OBS_STATE}."):
# observation.state.<feat>
if not keep(full_key):
continue
name = full_key[len(f"{OBS_STATE}.") :]
hw.setdefault("observation", {})[name] = ty
elif full_key.startswith(f"{OBS_IMAGES}."):
# observation.images.<cam>
# images obey ONLY the use_videos flag, not patterns
if not use_videos:
continue
name = full_key[len(f"{OBS_IMAGES}.") :]
hw.setdefault("observation", {})[name] = ty
else:
# anything else (e.g. policy-only features) is ignored here
continue
out: dict[str, dict] = {}
if ACTION in hw:
out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
if "observation" in hw:
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
return out
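The pattern filter above uses `re.search`, so a pattern matches anywhere in the key rather than the whole key. A standalone sketch of that behavior (the `keep` helper reproduced outside the function for illustration):

```python
import re

def keep(key: str, patterns=None) -> bool:
    # No patterns means keep everything, as in the function above.
    if patterns is None:
        return True
    # re.search matches anywhere in the key, so "action.ee" also
    # matches "action.ee.x"; anchor with ^...$ for exact keys.
    return any(re.search(pat, key) for pat in patterns)

features = ["action.ee.x", "action.gripper.pos", "observation.state.ee.x"]
selected = [k for k in features if keep(k, patterns=["action.ee"])]
print(selected)  # ['action.ee.x']
```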
+44
View File
@@ -470,6 +470,50 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
return policy_features
def combine_feature_dicts(*dicts: dict) -> dict:
"""
Merge LeRobot grouped feature dicts.
    - For 1D numeric specs (dtype not image/video/string) that carry "names": merge the names and recompute the shape.
    - For everything else (e.g. observation.images.*): the last definition wins.
"""
out: dict = {}
for d in dicts:
for key, value in d.items():
if not isinstance(value, dict):
out[key] = value
continue
dtype = value.get("dtype")
shape = value.get("shape")
is_vector = (
dtype not in ("image", "video", "string")
and isinstance(shape, tuple)
and len(shape) == 1
and "names" in value
)
if is_vector:
# Initialize or retrieve the accumulating dict for this feature key
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
# Ensure consistent data types across merged entries
if "dtype" in target and dtype != target["dtype"]:
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
# Merge feature names: append only new ones to preserve order without duplicates
seen = set(target["names"])
for n in value["names"]:
if n not in seen:
target["names"].append(n)
seen.add(n)
# Recompute the shape to reflect the updated number of features
target["shape"] = (len(target["names"]),)
else:
# For images/videos and non-1D entries: override with the latest definition
out[key] = value
return out
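The name-union and shape-recompute rule for 1D specs can be sketched in isolation (a hypothetical helper handling only the vector case, not the image/override branch):

```python
def merge_vector_features(*dicts):
    # Minimal sketch of the vector-merge rule above: union the
    # "names" in order, dedupe, and recompute shape from the count.
    out = {}
    for d in dicts:
        for key, spec in d.items():
            target = out.setdefault(key, {"dtype": spec["dtype"], "names": [], "shape": (0,)})
            if spec["dtype"] != target["dtype"]:
                raise ValueError(f"dtype mismatch for '{key}'")
            for n in spec["names"]:
                if n not in target["names"]:
                    target["names"].append(n)
            target["shape"] = (len(target["names"]),)
    return out

a = {"action": {"dtype": "float32", "shape": (2,), "names": ["ee.x", "ee.y"]}}
b = {"action": {"dtype": "float32", "shape": (1,), "names": ["gripper.pos"]}}
merged = merge_vector_features(a, b)
print(merged["action"]["shape"])  # (3,)
```

This is how the recording scripts above can combine `action_ee`, `gripper`, and `observation_ee` into one flat feature dict.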
def create_empty_dataset_info(
codebase_version: str,
fps: int,
+57 -86
View File
@@ -161,35 +161,73 @@ class XarmEnv(EnvConfig):
@dataclass
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
class ImagePreprocessingConfig:
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
@dataclass
class EnvTransformConfig:
"""Configuration for environment wrappers."""
class RewardClassifierConfig:
"""Configuration for reward classification."""
pretrained_path: str | None = None
success_threshold: float = 0.5
success_reward: float = 1.0
@dataclass
class InverseKinematicsConfig:
"""Configuration for inverse kinematics processing."""
urdf_path: str | None = None
target_frame_name: str | None = None
end_effector_bounds: dict[str, list[float]] | None = None
end_effector_step_sizes: dict[str, float] | None = None
@dataclass
class ObservationConfig:
"""Configuration for observation processing."""
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
control_mode: str = "gamepad"
display_cameras: bool = False
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Any | None = None
reset_time_s: float = 5.0
display_cameras: bool = False
@dataclass
class GripperConfig:
"""Configuration for gripper control and penalties."""
use_gripper: bool = True
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@dataclass
class ResetConfig:
"""Configuration for environment reset behavior."""
fixed_reset_joint_positions: Any | None = None
reset_time_s: float = 5.0
control_time_s: float = 20.0
terminate_on_success: bool = True
@dataclass
class HILSerlProcessorConfig:
"""Configuration for environment processing pipeline."""
control_mode: str = "gamepad"
observation: ObservationConfig | None = None
image_preprocessing: ImagePreprocessingConfig | None = None
gripper: GripperConfig | None = None
reset: ResetConfig | None = None
inverse_kinematics: InverseKinematicsConfig | None = None
reward_classifier: RewardClassifierConfig | None = None
max_gripper_pos: float | None = 100.0
@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
@@ -197,77 +235,10 @@ class HILSerlRobotEnvConfig(EnvConfig):
robot: RobotConfig | None = None
teleop: TeleoperatorConfig | None = None
wrapper: EnvTransformConfig | None = None
fps: int = 10
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
name: str = "real_robot"
mode: str | None = None # Either "record", "replay", None
repo_id: str | None = None
dataset_root: str | None = None
task: str | None = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: str | None = None
reward_classifier_pretrained_path: str | None = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
@property
def gym_kwargs(self) -> dict:
return {}
@EnvConfig.register_subclass("hil")
@dataclass
class HILEnvConfig(EnvConfig):
"""Configuration for the HIL environment."""
name: str = "PandaPickCube"
task: str | None = "PandaPickCubeKeyboard-v0"
use_viewer: bool = True
gripper_penalty: float = 0.0
use_gamepad: bool = True
state_dim: int = 18
action_dim: int = 4
fps: int = 100
episode_length: int = 100
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_STATE,
}
)
################# args from hilserlrobotenv
reward_classifier_pretrained_path: str | None = None
robot_config: RobotConfig | None = None
teleop_config: TeleoperatorConfig | None = None
wrapper: EnvTransformConfig | None = None
mode: str | None = None # Either "record", "replay", None
repo_id: str | None = None
dataset_root: str | None = None
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: str | None = None
# For the reward classifier, to record more positive examples after a success
number_of_steps_after_success: int = 0
############################
@property
def gym_kwargs(self) -> dict:
return {
"use_viewer": self.use_viewer,
"use_gamepad": self.use_gamepad,
"gripper_penalty": self.gripper_penalty,
}
+1 -3
View File
@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
elif env_type == "hil":
return HILEnvConfig(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
+11
View File
@@ -15,6 +15,17 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0.processor_pi0 import Pi0NewLineProcessor
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"PI0Config",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
]
-16
View File
@@ -35,7 +35,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.constants import ACTION, OBS_IMAGES
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -51,27 +50,16 @@ class ACTPolicy(PreTrainedPolicy):
def __init__(
self,
config: ACTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = ACT(config)
if config.temporal_ensemble_coeff is not None:
@@ -137,23 +125,19 @@ class ACTPolicy(PreTrainedPolicy):
"""Predict a chunk of actions given environment observations."""
self.eval()
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
actions = self.model(batch)[0]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
+70
View File
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessorStep,
UnnormalizerProcessorStep,
)
def make_act_pre_post_processors(
config: ACTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessorStep(rename_map={}),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
PolicyProcessorPipeline(
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -35,7 +35,6 @@ from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
get_device_from_parameters,
@@ -57,7 +56,6 @@ class DiffusionPolicy(PreTrainedPolicy):
def __init__(
self,
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
@@ -70,14 +68,6 @@ class DiffusionPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -106,9 +96,6 @@ class DiffusionPolicy(PreTrainedPolicy):
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
@@ -137,7 +124,6 @@ class DiffusionPolicy(PreTrainedPolicy):
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
@@ -153,11 +139,9 @@ class DiffusionPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
return loss, None
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessorStep,
UnnormalizerProcessorStep,
)
def make_diffusion_pre_post_processors(
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessorStep(rename_map={}),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
PolicyProcessorPipeline(
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
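The factory above wires a list of steps into a preprocessor and a postprocessor pipeline, each applied sequentially to a batch. A minimal pure-Python sketch of that sequential-pipeline pattern (class and step names here are hypothetical, not the lerobot API):

```python
class MiniPipeline:
    """Minimal sketch of a sequential processor pipeline: each step is a
    callable applied to the batch in order."""

    def __init__(self, steps, name):
        self.steps = steps
        self.name = name

    def __call__(self, batch):
        for step in self.steps:
            batch = step(batch)
        return batch


def normalize(batch):
    # Toy normalization: (x - 1) / 2 for every value.
    return {k: [(x - 1.0) / 2.0 for x in v] for k, v in batch.items()}


def add_batch_dim(batch):
    # Wrap each feature in a leading batch dimension.
    return {k: [v] for k, v in batch.items()}


pre = MiniPipeline(steps=[normalize, add_batch_dim], name="policy_preprocessor")
result = pre({"observation.state": [3.0, 5.0]})
```

The real pipelines follow the same shape: normalization before batching on the way in, and the inverse steps in reverse order on the way out.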
+169 -5
View File
@@ -14,12 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from __future__ import annotations
from torch import nn
import logging
from typing import Any, TypedDict
import torch
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.envs.configs import EnvConfig
@@ -34,9 +39,10 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs
def get_policy_class(name: str) -> PreTrainedPolicy:
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@@ -101,6 +107,163 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
raise ValueError(f"Policy type '{policy_type}' is not available.")
class ProcessorConfigKwargs(TypedDict, total=False):
"""Keyword arguments for the processor config."""
preprocessor_config_filename: str | None
postprocessor_config_filename: str | None
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
preprocessor_kwargs: ProcessorKwargs | None
postprocessor_kwargs: ProcessorKwargs | None
def make_pre_post_processors(
policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""Make a processor instance for a given policy type.
This function creates the appropriate processor configuration based on the policy type.
Each policy type has its own processor with specific preprocessing steps.
Args:
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
the processor from this path instead of creating a new one.
**kwargs: Additional keyword arguments passed to the processor creation.
Returns:
Tuple of (input_processor, output_processor) for the policy.
Raises:
NotImplementedError: If the policy type doesn't have a processor implemented.
"""
if pretrained_path:
# Extract preprocessor and postprocessor kwargs
preprocessor_kwargs = kwargs.get("preprocessor_kwargs", {})
postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})
return (
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=preprocessor_kwargs.get("to_transition"),
to_output=preprocessor_kwargs.get("to_output"),
),
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=postprocessor_kwargs.get("to_transition"),
to_output=postprocessor_kwargs.get("to_output"),
),
)
# Create a new processor based on policy type
if isinstance(policy_cfg, TDMPCConfig):
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
processors = make_tdmpc_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, DiffusionConfig):
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
processors = make_diffusion_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, ACTConfig):
from lerobot.policies.act.processor_act import make_act_pre_post_processors
processors = make_act_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
processors = make_vqbet_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, PI0Config):
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
processors = make_pi0_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, PI0FASTConfig):
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
processors = make_pi0fast_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, SACConfig):
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
processors = make_sac_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, RewardClassifierConfig):
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, SmolVLAConfig):
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
processors = make_smolvla_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
return processors
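`make_pre_post_processors` dispatches on the concrete config class with an `isinstance` chain, importing each per-policy factory lazily. A toy sketch of that dispatch pattern (config classes here are hypothetical stand-ins):

```python
class BaseConfig:
    pass


class DiffusionLikeConfig(BaseConfig):
    pass


class ActLikeConfig(BaseConfig):
    pass


def make_processors(cfg):
    # Dispatch on the concrete config type, as in the isinstance chain above.
    if isinstance(cfg, DiffusionLikeConfig):
        return ("diffusion_pre", "diffusion_post")
    if isinstance(cfg, ActLikeConfig):
        return ("act_pre", "act_post")
    raise NotImplementedError(f"No processors for '{type(cfg).__name__}'.")


pre, post = make_processors(DiffusionLikeConfig())
```

An unknown config type raises `NotImplementedError`, matching the fallback branch above.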
def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
@@ -147,7 +310,6 @@ def make_policy(
kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
kwargs["dataset_stats"] = ds_meta.stats
else:
if not cfg.pretrained_path:
logging.warning(
@@ -155,6 +317,8 @@ def make_policy(
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
if env_cfg is None:
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
@@ -171,7 +335,7 @@ def make_policy(
policy = policy_cls(**kwargs)
policy.to(cfg.device)
assert isinstance(policy, nn.Module)
assert isinstance(policy, torch.nn.Module)
# policy = torch.compile(policy, mode="reduce-overhead")
-420
View File
@@ -1,420 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def create_stats_buffers(
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
statistics.
Args: (see Normalize and Unnormalize)
Returns:
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` objects containing
`nn.Parameter` entries with `requires_grad=False`, so they are not updated during backpropagation.
"""
stats_buffers = {}
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
assert isinstance(norm_mode, NormalizationMode)
shape = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# sanity checks
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
buffer = {}
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
"std": nn.Parameter(std, requires_grad=False),
}
)
elif norm_mode is NormalizationMode.MIN_MAX:
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
"max": nn.Parameter(max, requires_grad=False),
}
)
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
stats_buffers[key] = buffer
return stats_buffers
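The buffers above start at infinity and are overwritten by `stats` (or later by `load_state_dict`); the forward pass then asserts nothing is still infinite. A minimal pure-Python sketch of that initialize-then-overwrite scheme (helper name is hypothetical):

```python
def make_mean_std_buffer(numel, stats=None):
    # Placeholders start at infinity; real statistics, when given, overwrite
    # them. Downstream code can then detect "stats were never loaded" by
    # checking for leftover infinities.
    buf = {"mean": [float("inf")] * numel, "std": [float("inf")] * numel}
    if stats is not None:
        buf["mean"] = list(stats["mean"])
        buf["std"] = list(stats["std"])
    return buf


filled = make_mean_std_buffer(2, stats={"mean": [0.5, 1.0], "std": [0.1, 0.2]})
empty = make_mean_std_buffer(2)
```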
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
"pretrained model."
)
class Normalize(nn.Module):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
features (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
are `PolicyFeature` entries whose shapes (e.g. `[3, 96, 96]`) are used to create the tensor buffers
containing mean, std, min, max statistics. If a feature is visual, its shape is adjusted to be
invariant to height and width, assuming a channel-first (c, h, w) format.
norm_map (dict): A dictionary where keys are feature types and values are their normalization
modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1)}`). If provided, as expected when
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.features = features
self.norm_map = norm_map
self.stats = stats
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# TODO: Remove this shallow copy
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail!
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(norm_mode)
return batch
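The two branches of `Normalize.forward` reduce to simple arithmetic. A pure-Python sketch of the same formulas, with the small epsilon guarding against division by zero:

```python
def normalize_mean_std(x, mean, std, eps=1e-8):
    # Same arithmetic as the MEAN_STD branch above.
    return (x - mean) / (std + eps)


def normalize_min_max(x, lo, hi, eps=1e-8):
    # Map to [0, 1], then rescale to [-1, 1], as in the MIN_MAX branch.
    x01 = (x - lo) / (hi - lo + eps)
    return x01 * 2 - 1


z = normalize_mean_std(3.0, mean=1.0, std=2.0)  # close to 1.0 (eps shifts it slightly)
m = normalize_min_max(5.0, lo=0.0, hi=10.0)     # close to 0.0, the midpoint of [-1, 1]
```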
class Unnormalize(nn.Module):
"""
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
original range used by the environment.
"""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
features (dict): A dictionary where keys are output modalities (e.g. "action") and values
are `PolicyFeature` entries whose shapes (e.g. `[3, 96, 96]`) are used to create the tensor buffers
containing mean, std, min, max statistics. If a feature is visual, its shape is adjusted to be
invariant to height and width, assuming a channel-first (c, h, w) format.
norm_map (dict): A dictionary where keys are feature types and values are their normalization
modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
and values are dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1)}`). If provided, as expected when
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.features = features
self.norm_map = norm_map
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(norm_mode)
return batch
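`Unnormalize` inverts the min/max mapping step by step: back from [-1, 1] to [0, 1], then to the original range. A pure-Python round-trip sketch of the pair (exact up to the epsilon in the forward direction):

```python
def normalize_min_max(x, lo, hi, eps=1e-8):
    # Forward: [lo, hi] -> [0, 1] -> [-1, 1].
    return ((x - lo) / (hi - lo + eps)) * 2 - 1


def unnormalize_min_max(y, lo, hi):
    # Inverse: [-1, 1] -> [0, 1] -> [lo, hi], as in the MIN_MAX branch above.
    return ((y + 1) / 2) * (hi - lo) + lo


x = 7.0
y = normalize_min_max(x, lo=2.0, hi=12.0)
x_back = unnormalize_min_max(y, lo=2.0, hi=12.0)  # recovers x up to eps
```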
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
module: nn.Module,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> None:
"""Register statistics buffers (mean/std or min/max) on the given *module*.
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
but is factored out so it can be reused by both classes and stay in sync.
"""
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
shape: tuple[int, ...] = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# reduce spatial dimensions, keep channel dimension only
c, *_ = shape
shape = (c, 1, 1)
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.full(shape, torch.inf, dtype=torch.float32)
std = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"]
std_data = stats[key]["std"]
if isinstance(mean_data, torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32)
std = std_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_mean", mean)
module.register_buffer(f"{prefix}_std", std)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"]
max_data = stats[key]["max"]
if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32)
max_val = max_data.clone().to(dtype=torch.float32)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
module.register_buffer(f"{prefix}_min", min_val)
module.register_buffer(f"{prefix}_max", max_val)
continue
raise ValueError(norm_mode)
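`_initialize_stats_buffers` derives buffer names from feature keys by replacing dots with underscores, then registers one pair of buffers per normalization mode. A small sketch of just the naming scheme (function name is hypothetical):

```python
def stat_buffer_names(key, norm_mode):
    # Dots in feature keys become underscores so the result is a valid
    # attribute name, mirroring the f"{prefix}_mean" / f"{prefix}_min"
    # scheme used with register_buffer above.
    prefix = key.replace(".", "_")
    if norm_mode == "mean_std":
        return [f"{prefix}_mean", f"{prefix}_std"]
    if norm_mode == "min_max":
        return [f"{prefix}_min", f"{prefix}_max"]
    raise ValueError(norm_mode)


names = stat_buffer_names("observation.image", "mean_std")
```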
class NormalizeBuffer(nn.Module):
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
batch[key] = batch[key] * 2 - 1
continue
raise ValueError(norm_mode)
return batch
class UnnormalizeBuffer(nn.Module):
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
def __init__(
self,
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
self.features = features
self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)  # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
mean = getattr(self, f"{prefix}_mean")
std = getattr(self, f"{prefix}_std")
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
continue
if norm_mode is NormalizationMode.MIN_MAX:
min_val = getattr(self, f"{prefix}_min")
max_val = getattr(self, f"{prefix}_max")
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max_val - min_val) + min_val
continue
raise ValueError(norm_mode)
return batch
+7 -139
View File
@@ -56,18 +56,15 @@ from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertConfig,
PaliGemmaWithExpertModel,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import log_model_loading_keys
from lerobot.utils.utils import get_safe_dtype, init_logging
from lerobot.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding(
@@ -223,28 +220,17 @@ class PI0Policy(PreTrainedPolicy):
def __init__(
self,
config: PI0Config,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FlowMatching(config)
self.reset()
@@ -253,99 +239,6 @@ class PI0Policy(PreTrainedPolicy):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
"""
Transform state dict keys to match expected model structure.
Transformations:
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
model.paligemma_with_expert.paligemma.lm_head
- model.paligemma_with_expert.paligemma.language_model.model ->
model.paligemma_with_expert.paligemma.model.language_model
- model.paligemma_with_expert.paligemma.vision_tower ->
model.paligemma_with_expert.paligemma.model.vision_tower
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
model.paligemma_with_expert.paligemma.model.multi_modal_projector
Also handles tied weights between lm_head.weight and
embed_tokens.weight.
"""
import re
transformed_dict = {}
transformations = [
(
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
".paligemma_with_expert.paligemma.lm_head",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
".paligemma_with_expert.paligemma.model.language_model",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
".paligemma_with_expert.paligemma.model.vision_tower",
),
(
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
".paligemma_with_expert.paligemma.model.multi_modal_projector",
),
]
for key, value in state_dict.items():
new_key = key
for pattern, replacement in transformations:
new_key = pattern.sub(replacement, new_key)
transformed_dict[new_key] = value
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
lm_head_key = None
embed_tokens_key = None
for key in transformed_dict:
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
lm_head_key = key
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
embed_tokens_key = key
if lm_head_key and embed_tokens_key:
break
if lm_head_key and not embed_tokens_key:
embed_tokens_key = lm_head_key.replace(
".lm_head.weight", ".model.language_model.embed_tokens.weight"
)
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
elif embed_tokens_key and not lm_head_key:
lm_head_key = embed_tokens_key.replace(
".model.language_model.embed_tokens.weight", ".lm_head.weight"
)
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
return transformed_dict
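The method above applies an ordered list of (compiled pattern, replacement) pairs to every key in the state dict. A self-contained sketch of that rename loop with hypothetical patterns (not the real PaliGemma key mapping):

```python
import re

# Hypothetical (pattern, replacement) pairs in the spirit of the
# transformations list above.
TRANSFORMATIONS = [
    (re.compile(r"\.language_model\.lm_head"), ".lm_head"),
    (re.compile(r"\.vision_tower"), ".model.vision_tower"),
]


def transform_keys(state_dict):
    out = {}
    for key, value in state_dict.items():
        new_key = key
        # Every pattern is tried on every key, in order.
        for pattern, replacement in TRANSFORMATIONS:
            new_key = pattern.sub(replacement, new_key)
        out[new_key] = value
    return out


renamed = transform_keys({
    "m.language_model.lm_head.weight": 1,
    "m.vision_tower.bias": 2,
})
```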
@classmethod
def _load_as_safetensor(
cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
) -> "PI0Policy":
"""Override to apply key transformations before loading."""
from safetensors.torch import load_file
init_logging()
# Load the state dict from file safely
state_dict = load_file(model_file, device=map_location)
# Apply key transformations
transformed_state_dict = cls._transform_state_dict_keys(state_dict)
# Load the transformed state dict
msg = model.load_state_dict(transformed_state_dict, strict=strict)
# Log message
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
return model
def get_optim_params(self) -> dict:
return self.parameters()
@@ -377,14 +270,13 @@ class PI0Policy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
@@ -394,8 +286,6 @@ class PI0Policy(PreTrainedPolicy):
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -410,12 +300,10 @@ class PI0Policy(PreTrainedPolicy):
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
@@ -482,26 +370,6 @@ class PI0Policy(PreTrainedPolicy):
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
# PaliGemma prompt has to end with a new line
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer(
tasks,
padding="max_length",
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
@@ -567,7 +435,7 @@ class PI0FlowMatching(nn.Module):
└──────────────────────────────┘
"""
def __init__(self, config):
def __init__(self, config: PI0Config):
super().__init__()
self.config = config
@@ -0,0 +1,118 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
ProcessorStep,
ProcessorStepRegistry,
RenameProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
"""Add a new line to the end of the task if it doesn't have one.
This is required for the PaliGemma tokenizer.
"""
def complementary_data(self, complementary_data):
if "task" not in complementary_data:
return complementary_data
task = complementary_data["task"]
if task is None:
return complementary_data
new_complementary_data = dict(complementary_data)
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
new_complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return new_complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
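The newline rule the processor applies (the PaliGemma prompt must end with "\n") can be exercised as a standalone sketch, independent of the processor framework; the function name here is illustrative, not part of the lerobot API:

```python
def ensure_trailing_newline(task):
    """Append "\n" to a task string (or each string in a list) if missing."""
    if isinstance(task, str):
        return task if task.endswith("\n") else f"{task}\n"
    if isinstance(task, list) and all(isinstance(t, str) for t in task):
        return [t if t.endswith("\n") else f"{t}\n" for t in task]
    # Anything else (None, tensors, ...) is passed through unchanged.
    return task
```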
def make_pi0_pre_post_processors(
config: PI0Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameProcessorStep(rename_map={}), # To mimic the same processor as the pretrained one
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(),
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
]
output_steps: list[ProcessorStep] = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
PolicyProcessorPipeline(
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
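The factory above returns a (preprocessor, postprocessor) pair, each an ordered list of steps. The underlying pattern — every step transforms the batch and hands it to the next — can be sketched generically; `MiniPipeline` and the lambda steps below are illustrative stand-ins, not the lerobot classes:

```python
from typing import Callable


class MiniPipeline:
    """Minimal stand-in for a processor pipeline: apply steps in order."""

    def __init__(self, steps: list[Callable[[dict], dict]], name: str):
        self.steps = steps
        self.name = name

    def __call__(self, batch: dict) -> dict:
        for step in self.steps:
            batch = step(batch)
        return batch


preprocessor = MiniPipeline(
    steps=[
        lambda b: {**b, "state": [x / 10.0 for x in b["state"]]},  # "normalize"
        lambda b: {**b, "batched": True},                          # "add batch dim"
    ],
    name="policy_preprocessor",
)
```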
@@ -58,7 +58,6 @@ from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -146,14 +145,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
@@ -221,8 +212,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
@@ -235,8 +224,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
] # self.config.max_action_dim # self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -249,8 +236,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.model.forward(batch)
return loss_dict["loss"], loss_dict
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessorStep,
UnnormalizerProcessorStep,
)
def make_pi0fast_pre_post_processors(
config: PI0Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessorStep(rename_map={}), # To mimic the same processor as the pretrained one
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
PolicyProcessorPipeline(
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -28,7 +28,6 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.policies.normalize import NormalizeBuffer
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
from lerobot.policies.utils import get_device_from_parameters
@@ -45,7 +44,6 @@ class SACPolicy(
def __init__(
self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__(config)
config.validate_features()
@@ -53,7 +51,6 @@ class SACPolicy(
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features["action"].shape[0]
self._init_normalization(dataset_stats)
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
@@ -88,8 +85,7 @@ class SACPolicy(
observations_features = None
if self.shared_encoder and self.actor.encoder.has_images:
# Cache and normalize image features
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
observations_features = self.actor.encoder.get_cached_image_features(batch)
actions, _, _ = self.actor(batch, observations_features)
@@ -391,28 +387,12 @@ class SACPolicy(
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _init_normalization(self, dataset_stats):
"""Initialize input/output normalization modules."""
self.normalize_inputs = nn.Identity()
self.normalize_targets = nn.Identity()
if self.config.dataset_stats is not None:
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
self.normalize_inputs = NormalizeBuffer(
self.config.input_features, self.config.normalization_mapping, params
)
stats = dataset_stats or params
self.normalize_targets = NormalizeBuffer(
self.config.output_features, self.config.normalization_mapping, stats
)
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_actor = (
self.encoder_critic
if self.shared_encoder
else SACObservationEncoder(self.config, self.normalize_inputs)
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
)
def _init_critics(self, continuous_action_dim):
@@ -424,9 +404,7 @@ class SACPolicy(
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
)
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
@@ -434,9 +412,7 @@ class SACPolicy(
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
)
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
@@ -490,10 +466,9 @@ class SACPolicy(
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
def __init__(self, config: SACConfig) -> None:
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self._init_image_layers()
self._init_state_layers()
self._compute_output_dim()
@@ -568,11 +543,10 @@ class SACObservationEncoder(nn.Module):
def forward(
self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
) -> Tensor:
obs = self.input_normalization(obs)
parts = []
if self.has_images:
if cache is None:
cache = self.get_cached_image_features(obs, normalize=False)
cache = self.get_cached_image_features(obs)
parts.append(self._encode_images(cache, detach))
if self.has_env:
parts.append(self.env_encoder(obs["observation.environment_state"]))
@@ -585,7 +559,7 @@ class SACObservationEncoder(nn.Module):
"No parts to concatenate, you should have at least one image or environment state or state"
)
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]:
"""Extract and optionally cache image features from observations.
This function processes image observations through the vision encoder once and returns
@@ -597,26 +571,17 @@ class SACObservationEncoder(nn.Module):
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
- Caching these features can provide 2-4x speedup in training and inference
Normalization behavior:
- When called from inside forward(): set normalize=False since inputs are already normalized
- When called from outside forward(): set normalize=True to ensure proper input normalization
Usage patterns:
- Called in select_action() with normalize=True
- Called in select_action()
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
- Called internally by forward() with normalize=False
- Called internally by forward()
Args:
obs: Dictionary of observation tensors containing image keys
normalize: Whether to normalize observations before encoding
Set to True when calling directly from outside the encoder's forward method
Set to False when calling from within forward() where inputs are already normalized
Returns:
Dictionary mapping image keys to their corresponding encoded features
"""
if normalize:
obs = self.input_normalization(obs)
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
out = self.image_encoder(batched)
chunks = torch.chunk(out, len(self.image_keys), dim=0)
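The caching trick used here — concatenate all camera streams, run the expensive vision encoder in a single forward pass, then split the output back per camera — can be illustrated with plain lists standing in for tensors (`encode_images_once` and the toy `encoder` are stand-ins for the real `torch.cat`/`torch.chunk` logic, not lerobot code):

```python
def encode_images_once(obs, image_keys, encoder):
    # Batch every camera's images together so the encoder runs once,
    # then slice the output back into one chunk per camera key.
    batched = [img for k in image_keys for img in obs[k]]
    out = encoder(batched)
    n = len(obs[image_keys[0]])
    return {k: out[i * n:(i + 1) * n] for i, k in enumerate(image_keys)}
```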
@@ -747,7 +712,6 @@ class CriticEnsemble(nn.Module):
Args:
encoder (SACObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
output_normalization (nn.Module): normalization layer for actions.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
@@ -757,13 +721,11 @@ class CriticEnsemble(nn.Module):
self,
encoder: SACObservationEncoder,
ensemble: list[CriticHead],
output_normalization: nn.Module,
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
def forward(
@@ -775,11 +737,6 @@ class CriticEnsemble(nn.Module):
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
# NOTE: We normalize actions because it helps with sample efficiency.
actions: dict[str, torch.Tensor] = {"action": actions}
# NOTE: The normalization layer takes a dict as input and returns a dict.
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = self.encoder(observations, cache=observation_features)
@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessorStep,
UnnormalizerProcessorStep,
)
def make_sac_pre_post_processors(
config: SACConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessorStep(rename_map={}),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
PolicyProcessorPipeline(
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -20,7 +20,6 @@ import torch
from torch import Tensor, nn
from lerobot.constants import OBS_IMAGE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -108,22 +107,12 @@ class Classifier(PreTrainedPolicy):
def __init__(
self,
config: RewardClassifierConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
from transformers import AutoModel
super().__init__(config)
self.config = config
# Initialize normalization (standardized with the policy framework)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# Set up encoder
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
@@ -247,10 +236,6 @@ class Classifier(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py."""
# Normalize inputs if needed
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images and labels
images, labels = self.extract_images_and_labels(batch)
@@ -0,0 +1,61 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.processor import (
DeviceProcessorStep,
IdentityProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
)
def make_classifier_processor(
config: RewardClassifierConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
NormalizerProcessorStep(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device=config.device),
]
output_steps = [DeviceProcessorStep(device="cpu"), IdentityProcessorStep()]
return (
PolicyProcessorPipeline(
steps=input_steps,
name="classifier_preprocessor",
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name="classifier_postprocessor",
**postprocessor_kwargs,
),
)
@@ -53,21 +53,13 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
"""
import math
import os
import re
from collections import deque
import safetensors
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import (
Normalize,
Unnormalize,
)
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
@@ -76,102 +68,6 @@ from lerobot.policies.utils import (
)
from lerobot.utils.utils import get_safe_dtype
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
def canonicalise(k: str) -> str:
"""
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
normalisation-buffer key.
"""
return _VARIANT_RE.sub(".buffer_", k)
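For example, the regex collapses any `.soNNN` marker, with or without a suffix like `-blue`, down to the plain `.buffer_` prefix (the buffer keys below are made up for illustration):

```python
import re

# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")


def canonicalise(k: str) -> str:
    """Strip dataset-variant markers from a normalisation-buffer key."""
    return _VARIANT_RE.sub(".buffer_", k)


# ".so100-blue_buffer_" and ".so100_buffer_" both collapse to ".buffer_";
# keys without a variant marker pass through untouched.
```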
def standardise_state_dict(
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
"""
Re-keys `checkpoint` so that every entry matches the *reference* key set.
If several variant keys collapse to the same canonical name we keep the
first one and log the collision.
Returns the new dict + a list of entries that could not be matched.
"""
out, collisions, unmatched = {}, {}, []
for k, v in checkpoint.items():
canon = canonicalise(k)
if canon in ref_keys:
if canon in out: # duplicate after collapsing
collisions.setdefault(canon, []).append(k)
else:
out[canon] = v
else:
unmatched.append(k)
if verbose:
for canon, variants in collisions.items():
print(f"[standardise_state_dict] '{canon}' collides with {variants}, kept the first")
if unmatched:
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
out.update({k: checkpoint[k] for k in unmatched})
return out, unmatched
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
"""
Renames keys in a checkpoint dictionary based on the given rename string.
Args:
checkpoint (dict): The checkpoint dictionary.
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
Returns:
dict: The modified checkpoint with renamed keys.
"""
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
new_checkpoint = {}
for k, v in checkpoint.items():
for old_key, new_key in rename_dict.items():
if old_key in k:
k = k.replace(old_key, new_key)
new_checkpoint[k] = v
return new_checkpoint
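A quick worked example of the `old1//new1,old2//new2` mapping format the function parses (the checkpoint keys are illustrative):

```python
def rename_checkpoint_keys(checkpoint: dict, rename_str: str) -> dict:
    """Rename checkpoint keys using a "old1//new1,old2//new2" mapping string."""
    # "old1//new1,old2//new2" -> {"old1": "new1", "old2": "new2"}
    rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
    new_checkpoint = {}
    for k, v in checkpoint.items():
        for old_key, new_key in rename_dict.items():
            if old_key in k:
                k = k.replace(old_key, new_key)
        new_checkpoint[k] = v
    return new_checkpoint


# Strips the torch.compile "_orig_mod." wrapper prefix from every matching key.
renamed = rename_checkpoint_keys(
    {"model._orig_mod.blocks.0.weight": 0.5},
    "model._orig_mod.//model.",
)
```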
def load_smolvla(
model: torch.nn.Module,
filename: str | os.PathLike,
*,
device: str = "cpu",
checkpoint_keys_mapping: str = "",
) -> torch.nn.Module:
state_dict = safetensors.torch.load_file(filename, device=device)
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
raise RuntimeError(
f"SmolVLA checkpoint load: {len(missing)} missing / {len(unexpected)} unexpected keys"
)
return model
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
@@ -326,28 +222,17 @@ class SmolVLAPolicy(PreTrainedPolicy):
def __init__(
self,
config: SmolVLAConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config)
self.reset()
@@ -357,23 +242,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
# HACK(aliberts, danaaubakirova): we override this classmethod here to fix SmolVLA-specific issues
@classmethod
def _load_as_safetensor(
cls,
model: "SmolVLAPolicy",
model_file: str,
map_location: str,
strict: bool,
):
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return load_smolvla(
model,
model_file,
device=map_location,
checkpoint_keys_mapping="model._orig_mod.//model.",
)
def get_optim_params(self) -> dict:
return self.parameters()
@@ -389,7 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
@@ -397,8 +266,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -408,8 +275,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
return batch
@torch.no_grad()
@@ -450,11 +315,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
loss_dict = {}
@@ -518,30 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer(
tasks,
padding=self.config.pad_language_to,
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
@@ -0,0 +1,111 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
ProcessorStepRegistry,
RenameProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
def make_smolvla_pre_post_processors(
config: SmolVLAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessorStep(rename_map={}), # To mimic the same processor as the pretrained one
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(),
SmolVLANewLineProcessor(),
TokenizerProcessorStep(
tokenizer_name=config.vlm_model_name,
padding=config.pad_language_to,
padding_side="right",
max_length=config.tokenizer_max_length,
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
PolicyProcessorPipeline(
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
"""Add a new line to the end of the task if it doesn't have one."""
def complementary_data(self, complementary_data):
if "task" not in complementary_data:
return complementary_data
task = complementary_data["task"]
if task is None:
return complementary_data
new_complementary_data = dict(complementary_data)
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
new_complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return new_complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
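The newline-appending logic of `SmolVLANewLineProcessor` can be sketched in isolation, without the lerobot processor base classes (`ensure_trailing_newline` is a hypothetical helper name used only for illustration):

```python
def ensure_trailing_newline(task):
    """Append a trailing newline to a task string, or to each string in a list, if missing."""
    if isinstance(task, str):
        return task if task.endswith("\n") else f"{task}\n"
    if isinstance(task, list) and all(isinstance(t, str) for t in task):
        return [t if t.endswith("\n") else f"{t}\n" for t in task]
    return task  # leave non-string tasks unchanged, as the processor does
```

The tokenizer for the SmolVLA VLM expects the task prompt to end with a newline, which is why the step normalizes both single strings and batched lists.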
@@ -36,7 +36,6 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
@@ -63,26 +62,19 @@ class TDMPCPolicy(PreTrainedPolicy):
config_class = TDMPCConfig
name = "tdmpc"
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
def __init__(
self,
config: TDMPCConfig,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
for param in self.model_target.parameters():
@@ -137,7 +129,6 @@ class TDMPCPolicy(PreTrainedPolicy):
actions = torch.clamp(actions, -1, +1)
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
@@ -147,11 +138,12 @@ class TDMPCPolicy(PreTrainedPolicy):
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
self._queues = populate_queues(self._queues, batch)
@@ -320,11 +312,9 @@ class TDMPCPolicy(PreTrainedPolicy):
"""
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
info = {}
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessorStep,
UnnormalizerProcessorStep,
)
def make_tdmpc_pre_post_processors(
config: TDMPCConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessorStep(rename_map={}),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
PolicyProcessorPipeline(
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -28,7 +28,6 @@ import torchvision
from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -48,7 +47,6 @@ class VQBeTPolicy(PreTrainedPolicy):
def __init__(
self,
config: VQBeTConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
@@ -61,14 +59,6 @@ class VQBeTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.vqbet = VQBeTModel(config)
self.reset()
@@ -128,7 +118,6 @@ class VQBeTPolicy(PreTrainedPolicy):
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
@@ -142,10 +131,12 @@ class VQBeTPolicy(PreTrainedPolicy):
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# NOTE: It's important that this happens after stacking the images into a single key.
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
self._queues = populate_queues(self._queues, batch)
@@ -165,10 +156,8 @@ class VQBeTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
# loss: total loss of training RVQ
@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
RenameProcessorStep,
UnnormalizerProcessorStep,
)
def make_vqbet_pre_post_processors(
config: VQBeTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessorStep(rename_map={}), # Give the user the option to rename keys
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return (
PolicyProcessorPipeline(
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -14,41 +14,90 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .device_processor import DeviceProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import VanillaObservationProcessor
from .batch_processor import AddBatchDimensionProcessorStep
from .converters import (
batch_to_transition,
create_transition,
merge_transitions,
transition_to_batch,
transition_to_dataset_frame,
)
from .core import EnvTransition, TransitionKey
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .device_processor import DeviceProcessorStep
from .gym_action_processor import Numpy2TorchActionProcessorStep, Torch2NumpyActionProcessorStep
from .hil_processor import (
AddTeleopActionAsComplimentaryDataStep,
AddTeleopEventsAsInfoStep,
GripperPenaltyProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
RewardClassifierProcessorStep,
TimeLimitProcessorStep,
)
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
from .observation_processor import VanillaObservationProcessorStep
from .pipeline import (
ActionProcessor,
DoneProcessor,
EnvTransition,
IdentityProcessor,
InfoProcessor,
ObservationProcessor,
ActionProcessorStep,
ComplementaryDataProcessorStep,
DataProcessorPipeline,
DoneProcessorStep,
IdentityProcessorStep,
InfoProcessorStep,
ObservationProcessorStep,
PolicyProcessorPipeline,
ProcessorKwargs,
ProcessorStep,
ProcessorStepRegistry,
RewardProcessor,
RobotProcessor,
TransitionKey,
TruncatedProcessor,
RewardProcessorStep,
RobotProcessorPipeline,
TruncatedProcessorStep,
)
from .rename_processor import RenameProcessor
from .rename_processor import RenameProcessorStep
from .tokenizer_processor import TokenizerProcessorStep
__all__ = [
"ActionProcessor",
"DeviceProcessor",
"DoneProcessor",
"ActionProcessorStep",
"AddTeleopActionAsComplimentaryDataStep",
"AddTeleopEventsAsInfoStep",
"ComplementaryDataProcessorStep",
"batch_to_transition",
"create_transition",
"DeviceProcessorStep",
"DoneProcessorStep",
"EnvTransition",
"IdentityProcessor",
"InfoProcessor",
"NormalizerProcessor",
"UnnormalizerProcessor",
"ObservationProcessor",
"GripperPenaltyProcessorStep",
"hotswap_stats",
"IdentityProcessorStep",
"ImageCropResizeProcessorStep",
"InfoProcessorStep",
"InterventionActionProcessorStep",
"JointVelocityProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"merge_transitions",
"MotorCurrentProcessorStep",
"NormalizerProcessorStep",
"Numpy2TorchActionProcessorStep",
"ObservationProcessorStep",
"PolicyProcessorPipeline",
"ProcessorKwargs",
"ProcessorStep",
"ProcessorStepRegistry",
"RenameProcessor",
"RewardProcessor",
"RobotProcessor",
"RenameProcessorStep",
"RewardClassifierProcessorStep",
"RewardProcessorStep",
"DataProcessorPipeline",
"TimeLimitProcessorStep",
"AddBatchDimensionProcessorStep",
"RobotProcessorPipeline",
"TokenizerProcessorStep",
"Torch2NumpyActionProcessorStep",
"transition_to_batch",
"transition_to_dataset_frame",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",
"TruncatedProcessorStep",
"UnnormalizerProcessorStep",
"VanillaObservationProcessorStep",
]
@@ -0,0 +1,159 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from .core import EnvTransition
from .pipeline import (
ActionProcessorStep,
ComplementaryDataProcessorStep,
ObservationProcessorStep,
ProcessorStep,
ProcessorStepRegistry,
)
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_action")
class AddBatchDimensionActionStep(ActionProcessorStep):
"""Process action component in-place, adding batch dimension if needed."""
def action(self, action):
if not isinstance(action, Tensor) or action.dim() != 1:
return action
return action.unsqueeze(0)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
class AddBatchDimensionObservationStep(ObservationProcessorStep):
"""Process observation component in-place, adding batch dimensions where needed."""
def observation(self, observation):
# Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]:
if state_key in observation:
state_value = observation[state_key]
if isinstance(state_value, Tensor) and state_value.dim() == 1:
observation[state_key] = state_value.unsqueeze(0)
# Process single image observation - add batch dim if 3D
if OBS_IMAGE in observation:
image_value = observation[OBS_IMAGE]
if isinstance(image_value, Tensor) and image_value.dim() == 3:
observation[OBS_IMAGE] = image_value.unsqueeze(0)
# Process multiple image observations - add batch dim if 3D
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
observation[key] = value.unsqueeze(0)
return observation
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
"""Process complementary data in-place, handling task field batching."""
def complementary_data(self, complementary_data):
# Process task field - wrap string in list to add batch dimension
if "task" in complementary_data:
task_value = complementary_data["task"]
if isinstance(task_value, str):
complementary_data["task"] = [task_value]
# Process index field - add batch dim if 0D
if "index" in complementary_data:
index_value = complementary_data["index"]
if isinstance(index_value, Tensor) and index_value.dim() == 0:
complementary_data["index"] = index_value.unsqueeze(0)
# Process task_index field - add batch dim if 0D
if "task_index" in complementary_data:
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
return complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class AddBatchDimensionProcessorStep(ProcessorStep):
"""Processor that adds batch dimensions to observations and actions when needed.
This processor ensures that observations and actions have proper batch dimensions for model processing:
- For state observations (observation.state, observation.environment_state):
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For image observations (observation.image, observation.images.*):
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
- For actions:
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For task field in complementary data:
Wraps string task in a list to add batch dimension
(task must be a string or list of strings)
This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to
batched model inputs.
The processor only modifies tensors that need batching and leaves already
batched tensors unchanged.
Example:
```python
# State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3)
# Action: (4,) -> (1, 4)
# Task: "pick_cube" -> ["pick_cube"]
# Already batched: (1, 7) -> (1, 7) [unchanged]
```
"""
to_batch_action_processor: AddBatchDimensionActionStep = field(
default_factory=AddBatchDimensionActionStep
)
to_batch_observation_processor: AddBatchDimensionObservationStep = field(
default_factory=AddBatchDimensionObservationStep
)
to_batch_complementary_data_processor: AddBatchDimensionComplementaryDataStep = field(
default_factory=AddBatchDimensionComplementaryDataStep
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = self.to_batch_action_processor(transition)
transition = self.to_batch_observation_processor(transition)
transition = self.to_batch_complementary_data_processor(transition)
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# NOTE: We ignore the batch dimension when transforming features
return features
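The unsqueeze-if-unbatched rule described in the docstring above can be sketched standalone, using numpy arrays in place of torch tensors (`add_batch_dim` is a hypothetical helper; `arr[None, ...]` plays the role of torch's `unsqueeze(0)`):

```python
import numpy as np

def add_batch_dim(value, expected_ndim):
    """Prepend a batch axis when the array has the unbatched rank; leave batched input unchanged."""
    if isinstance(value, np.ndarray) and value.ndim == expected_ndim:
        return value[None, ...]  # equivalent of torch's unsqueeze(0)
    return value

state = add_batch_dim(np.zeros(7), expected_ndim=1)              # (7,) -> (1, 7)
image = add_batch_dim(np.zeros((224, 224, 3)), expected_ndim=3)  # -> (1, 224, 224, 3)
batched = add_batch_dim(np.zeros((1, 7)), expected_ndim=1)       # already batched, unchanged
```

The rank check is what makes the step idempotent: running the processor twice on the same transition leaves already-batched tensors alone.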
@@ -0,0 +1,481 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from copy import deepcopy
from functools import singledispatch
from typing import Any
import numpy as np
import torch
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
from lerobot.utils.rotation import Rotation
from .core import EnvTransition, TransitionKey
@singledispatch
def to_tensor(
value: Any,
*,
dtype: torch.dtype | None = torch.float32,
device: torch.device | str | None = None,
) -> torch.Tensor:
"""
Convert various data types to PyTorch tensors with configurable options.
This is a unified tensor conversion function using single dispatch to handle
different input types appropriately.
Args:
value: Input value to convert (tensor, array, scalar, sequence, etc.)
dtype: Target tensor dtype. If None, preserves original dtype.
device: Target device for the tensor.
Returns:
PyTorch tensor.
Raises:
TypeError: If the input type is not supported.
"""
raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")
@to_tensor.register(torch.Tensor)
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle existing PyTorch tensors."""
if dtype is not None:
value = value.to(dtype=dtype)
if device is not None:
value = value.to(device=device)
return value
@to_tensor.register(np.ndarray)
def _(
value: np.ndarray,
*,
dtype=torch.float32,
device=None,
**kwargs,
) -> torch.Tensor:
"""Handle numpy arrays."""
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
if value.ndim == 0:
# Numpy scalars should be converted to 0-dimensional tensors
scalar_value = value.item()
return torch.tensor(scalar_value, dtype=dtype, device=device)
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
tensor = torch.from_numpy(value)
# Apply dtype conversion if specified
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if device is not None:
tensor = tensor.to(device=device)
return tensor
@to_tensor.register(int)
@to_tensor.register(float)
@to_tensor.register(np.integer)
@to_tensor.register(np.floating)
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle scalar values including numpy scalars."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(list)
@to_tensor.register(tuple)
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle sequences (lists, tuples)."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(dict)
def _(value: dict, *, device=None, **kwargs) -> dict:
"""Handle dictionaries by recursively converting values to tensors."""
if not value:
return {}
result = {}
for key, sub_value in value.items():
if sub_value is None:
continue
if isinstance(sub_value, dict):
# Recursively process nested dictionaries
result[key] = to_tensor(
sub_value,
device=device,
**kwargs,
)
continue
# Convert individual values to tensors
result[key] = to_tensor(
sub_value,
device=device,
**kwargs,
)
return result
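The `@singledispatch` pattern that `to_tensor` relies on can be illustrated with a torch-free sketch (`to_list` is a hypothetical converter standing in for the tensor conversion; the dict handler recurses just like the one above):

```python
from functools import singledispatch

@singledispatch
def to_list(value):
    """Fallback: reject unsupported types, like the to_tensor base case."""
    raise TypeError(f"Unsupported type for conversion: {type(value)}")

@to_list.register(int)
@to_list.register(float)
def _(value):
    return [float(value)]  # scalars become one-element lists

@to_list.register(list)
@to_list.register(tuple)
def _(value):
    return [float(v) for v in value]

@to_list.register(dict)
def _(value):
    # recurse into nested dicts, skipping None values as the dict handler does
    return {k: to_list(v) for k, v in value.items() if v is not None}
```

Single dispatch keeps each type's conversion logic in its own small function instead of one long `isinstance` chain, and new types can be supported by registering another handler.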
def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
"""Convert tensor to numpy/scalar if needed."""
if isinstance(x, torch.Tensor):
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
return x
def _is_image(arr: Any) -> bool:
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
state, images = {}, {}
for k, v in obs.items():
if "image" in k.lower() or _is_image(v):
images[k] = v
else:
state[k] = v
return state, images
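The image/state split above uses both the key name and an array heuristic; a minimal sketch of that heuristic (numpy arrays standing in for raw camera frames):

```python
import numpy as np

def looks_like_image(arr):
    """Heuristic from _is_image: uint8 arrays with 3 dims (H, W, C) are treated as camera frames."""
    return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3

obs = {"wrist_cam": np.zeros((48, 64, 3), dtype=np.uint8), "joint.pos": 0.3}
images = {k: v for k, v in obs.items() if "image" in k.lower() or looks_like_image(v)}
state = {k: v for k, v in obs.items() if k not in images}
```

Checking the dtype and rank means cameras are routed correctly even when their key names don't contain "image".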
# ============================================================================
# Private Helper Functions (Common Logic)
# ============================================================================
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""Extract complementary data (pad flags, task, index, task_index)."""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key}
def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition:
"""Merge two transitions, with other taking precedence."""
out = deepcopy(base)
for key in (
TransitionKey.OBSERVATION,
TransitionKey.ACTION,
TransitionKey.INFO,
TransitionKey.COMPLEMENTARY_DATA,
):
if other.get(key):
out.setdefault(key, {}).update(deepcopy(other[key]))
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
if k in other:
out[k] = other[k]
return out
# ============================================================================
# Core Conversion Functions
# ============================================================================
def create_transition(
observation: dict[str, Any] | None = None,
action: dict[str, Any] | None = None,
reward: float = 0.0,
done: bool = False,
truncated: bool = False,
info: dict[str, Any] | None = None,
complementary_data: dict[str, Any] | None = None,
) -> EnvTransition:
"""Create an EnvTransition with sensible defaults.
Args:
observation: Observation dictionary.
action: Action dictionary.
reward: Scalar reward value.
done: Episode termination flag.
truncated: Episode truncation flag.
info: Additional info dictionary.
complementary_data: Complementary data dictionary.
Returns:
Complete EnvTransition dictionary.
"""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info if info is not None else {},
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
}
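A standalone sketch of the default-filling behavior, mimicking `TransitionKey` with a stdlib `Enum` (`Key` and `make_transition` are stand-in names, not the lerobot API):

```python
from enum import Enum

class Key(Enum):  # stand-in for lerobot's TransitionKey
    OBSERVATION = "observation"
    ACTION = "action"
    REWARD = "reward"
    DONE = "done"
    TRUNCATED = "truncated"
    INFO = "info"
    COMPLEMENTARY_DATA = "complementary_data"

def make_transition(observation=None, action=None, reward=0.0, done=False,
                    truncated=False, info=None, complementary_data=None):
    """Build a transition dict; dict-valued fields default to {} rather than None."""
    return {
        Key.OBSERVATION: observation,
        Key.ACTION: action,
        Key.REWARD: reward,
        Key.DONE: done,
        Key.TRUNCATED: truncated,
        Key.INFO: info if info is not None else {},
        Key.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
    }

tr = make_transition(observation={"observation.state": [0.0]})
```

Defaulting `info` and `complementary_data` to `{}` lets downstream steps call `.update()` on them without a None check.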
def action_to_transition(action: dict[str, Any]) -> EnvTransition:
"""
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
"""
act_dict: dict[str, Any] = {}
for k, v in action.items():
# Check if the value is a type that should not be converted to a tensor.
if isinstance(v, (Rotation, dict)):
act_dict[f"{ACTION}.{k}"] = v
continue
arr = np.array(v) if np.isscalar(v) else v
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
return create_transition(observation={}, action=act_dict)
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
"""
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
"""
state, images = _split_obs_to_state_and_images(observation)
obs_dict: dict[str, Any] = {}
for k, v in state.items():
arr = np.array(v) if np.isscalar(v) else v
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
for cam, img in images.items():
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
return create_transition(observation=obs_dict, action={})
def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]:
"""
Convert an EnvTransition's ACTION entry to a raw robot action dict, keeping keys ending in '.pos' or '.vel' and stripping the 'action.' prefix.
"""
out: dict[str, Any] = {}
action_dict = transition.get(TransitionKey.ACTION) or {}
for k, v in action_dict.items():
if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
out[out_key] = float(v)
return out
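The prefix stripping can be shown in a self-contained sketch (assuming `lerobot.constants.ACTION == "action"`, which this snippet hardcodes):

```python
ACTION = "action"  # assumed value of lerobot.constants.ACTION

def strip_action_prefix(action_dict):
    """Keep only '.pos'/'.vel' entries and drop the 'action.' prefix, as the function above does."""
    out = {}
    for k, v in action_dict.items():
        if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
            out[k[len(f"{ACTION}.") :]] = float(v)
    return out

robot_action = strip_action_prefix(
    {"action.shoulder.pos": 0.5, "action.gripper.vel": -0.1, "action.extra": 1.0}
)
```

Entries without a `.pos`/`.vel` suffix (here `action.extra`) are silently dropped, so only joint commands reach the robot.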
def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition:
"""Merge multiple transitions or return single transition.
Args:
transitions: Either a single transition or iterable of transitions.
Returns:
Merged EnvTransition.
"""
if not isinstance(transitions, Sequence): # Single transition
return transitions
items = list(transitions)
if not items:
raise ValueError("merge_transitions() requires a non-empty sequence of transitions")
result = items[0]
for t in items[1:]:
result = _merge_transitions(result, t)
return result
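The merge semantics of `_merge_transitions` (dict-valued keys are unioned, scalar keys overwritten, later transitions winning) can be sketched with plain string keys (`merge` is an illustrative stand-in, not the lerobot function):

```python
from copy import deepcopy

def merge(base, other):
    """Later transition wins: dict-valued keys are updated in place, scalar keys overwritten."""
    out = deepcopy(base)
    for key in ("observation", "action", "info"):
        if other.get(key):
            out.setdefault(key, {}).update(deepcopy(other[key]))
    for key in ("reward", "done", "truncated"):
        if key in other:
            out[key] = other[key]
    return out

a = {"observation": {"x": 1}, "reward": 0.0}
b = {"observation": {"y": 2}, "reward": 1.0}
merged = merge(a, b)
```

`deepcopy` keeps the inputs untouched, so a transition can safely be merged into several different results.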
def transition_to_dataset_frame(
transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict]
) -> dict[str, Any]:
"""Convert a single EnvTransition or an iterable of them into a flat, dataset-friendly dictionary for training or evaluation.
Processes transitions according to the provided feature specification and returns
data in the format expected by machine learning models and datasets.
Args:
transitions_or_transition: Either a single EnvTransition dict or an iterable of them
(which will be merged using merge_transitions).
features: Feature specification dictionary with the following structure:
- 'action': dict with 'names': list of action feature names
- 'observation.state': dict with 'names': list of state feature names
- keys starting with 'observation.images.' are passed through as-is
Returns:
Flat dictionary containing:
- numpy arrays for "observation.state" and "action" (vectorized from feature names)
- any image tensors defined in features (passed through unchanged)
- next.{reward,done,truncated} scalar values
- info dict
- *_is_pad flags and task from complementary_data
"""
action_names = features.get(ACTION, {}).get("names", [])
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
tr = merge_transitions(transitions_or_transition)
obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
act = tr.get(TransitionKey.ACTION, {}) or {}
batch: dict[str, Any] = {}
# Images passthrough
for k in image_keys:
if k in obs:
batch[k] = obs[k]
# Observation.state vector
if obs_state_names:
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
# Action vector
if action_names:
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
batch[ACTION] = np.asarray(vals, dtype=np.float32)
# Add transition metadata
if tr.get(TransitionKey.REWARD) is not None:
reward_val = _from_tensor(tr[TransitionKey.REWARD])
# Check if features expect array format, otherwise keep as scalar
if REWARD in features and features[REWARD].get("shape") == (1,):
batch[REWARD] = np.array([reward_val], dtype=np.float32)
else:
batch[REWARD] = reward_val
if tr.get(TransitionKey.DONE) is not None:
done_val = _from_tensor(tr[TransitionKey.DONE])
if DONE in features and features[DONE].get("shape") == (1,):
batch[DONE] = np.array([done_val], dtype=bool)
else:
batch[DONE] = done_val
if tr.get(TransitionKey.TRUNCATED) is not None:
truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED])
if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
else:
batch[TRUNCATED] = truncated_val
# Complementary data flags and task
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if comp:
# pad flags
for k, v in comp.items():
if k.endswith("_is_pad"):
batch[k] = v
# task label
if comp.get("task") is not None:
batch["task"] = comp["task"]
return batch
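The core vectorization step above, turning per-feature scalars into one ordered float32 vector, can be shown standalone (`vectorize` is a hypothetical helper; missing features default to 0.0, matching the `obs.get(..., 0.0)` calls):

```python
import numpy as np

def vectorize(flat, prefix, names):
    """Collect per-feature scalars into one float32 vector, defaulting missing entries to 0.0."""
    return np.asarray([flat.get(f"{prefix}.{n}", 0.0) for n in names], dtype=np.float32)

obs = {"observation.state.x": 0.1, "observation.state.y": 0.2}
state_vec = vectorize(obs, "observation.state", ["x", "y", "z"])  # "z" missing -> 0.0
```

The `names` list from the feature spec fixes the element order, so the same vector layout is produced for every frame regardless of dict iteration order.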
def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
"""Convert a batch dict coming from LeRobot replay/dataset code into an EnvTransition dictionary.
The function maps well known keys to the EnvTransition structure. Missing keys are
filled with sane defaults (None or 0.0/False).
Keys recognised (case-sensitive):
* "observation.*" (keys starting with "observation." are grouped into observation dict)
* "action"
* "next.reward"
* "next.done"
* "next.truncated"
* "info"
* "_is_pad" patterns (padding flags)
* "task", "index", "task_index" (complementary data)
Additional keys are ignored so that existing dataloaders can carry extra
metadata without breaking the processor.
Args:
batch: Batch dictionary from datasets or dataloaders containing the above keys.
Returns:
EnvTransition dictionary with properly structured transition data.
"""
# Validate input type
if not isinstance(batch, dict):
raise ValueError(f"batch must be a dictionary. Got {type(batch).__name__}")
# Extract observation keys
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
complementary_data = _extract_complementary_data(batch)
return create_transition(
observation=observation_keys if observation_keys else None,
action=batch.get("action"),
reward=batch.get("next.reward", 0.0),
done=batch.get("next.done", False),
truncated=batch.get("next.truncated", False),
info=batch.get("info", {}),
complementary_data=complementary_data if complementary_data else None,
)
def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"""Inverse of batch_to_transition. Returns a dict with canonical field names used throughout LeRobot.
Converts an EnvTransition back to the batch format expected by datasets, dataloaders,
and other LeRobot components.
Output format:
* "action": Action data from transition
* "next.reward": Reward value (defaults to 0.0)
* "next.done": Done flag (defaults to False)
* "next.truncated": Truncated flag (defaults to False)
* "info": Info dictionary (defaults to {})
* Flattened observation keys (e.g., "observation.state", "observation.images.cam1")
* Complementary data fields ("task", "index", "task_index", padding flags)
Args:
transition: EnvTransition dictionary to convert.
Returns:
Batch dictionary with canonical LeRobot field names suitable for dataloaders.
"""
batch = {
"action": transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}),
}
# Add complementary data
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if comp_data:
batch.update(comp_data)
# Flatten observation dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
batch.update(observation)
return batch
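The two converters above are intended to be inverses for the canonical keys. A stdlib-only sketch of the key mapping (hypothetical function names; the real implementations use `create_transition`, `TransitionKey`, and `_extract_complementary_data`):

```python
def batch_to_transition_sketch(batch: dict) -> dict:
    # Group "observation.*" keys and map the "next.*" keys to transition fields.
    observation = {k: v for k, v in batch.items() if k.startswith("observation.")}
    return {
        "observation": observation or None,
        "action": batch.get("action"),
        "reward": batch.get("next.reward", 0.0),
        "done": batch.get("next.done", False),
        "truncated": batch.get("next.truncated", False),
        "info": batch.get("info", {}),
    }

def transition_to_batch_sketch(transition: dict) -> dict:
    # Inverse mapping: restore canonical "next.*" names and flatten observations.
    batch = {
        "action": transition.get("action"),
        "next.reward": transition.get("reward", 0.0),
        "next.done": transition.get("done", False),
        "next.truncated": transition.get("truncated", False),
        "info": transition.get("info", {}),
    }
    if isinstance(transition.get("observation"), dict):
        batch.update(transition["observation"])
    return batch

batch = {"observation.state": [0.1, 0.2], "action": [1.0], "next.reward": 0.5}
roundtrip = transition_to_batch_sketch(batch_to_transition_sketch(batch))
assert roundtrip["next.reward"] == 0.5
assert roundtrip["observation.state"] == [0.1, 0.2]
```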
@@ -0,0 +1,49 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from enum import Enum
from typing import Any, TypedDict
import torch
class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components."""
# TODO(Steven): Use consts
OBSERVATION = "observation"
ACTION = "action"
REWARD = "reward"
DONE = "done"
TRUNCATED = "truncated"
INFO = "info"
COMPLEMENTARY_DATA = "complementary_data"
EnvTransition = TypedDict(
"EnvTransition",
{
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
TransitionKey.ACTION.value: Any | torch.Tensor | None,
TransitionKey.REWARD.value: float | torch.Tensor | None,
TransitionKey.DONE.value: bool | torch.Tensor | None,
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
TransitionKey.INFO.value: dict[str, Any] | None,
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
},
)
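The functional `TypedDict(...)` form is used here because the keys are computed from enum values rather than written as literals. A minimal self-contained illustration of the pattern (simplified key set, no torch):

```python
from enum import Enum
from typing import Any, TypedDict

class Key(str, Enum):
    OBSERVATION = "observation"
    ACTION = "action"

# Class syntax would not allow computed keys; the functional form does.
Transition = TypedDict("Transition", {Key.OBSERVATION.value: Any, Key.ACTION.value: Any})

t: Transition = {"observation": {"state": [0.0]}, "action": [1.0]}
# Because Key subclasses str, enum members hash and compare equal to their
# values, so they can index the plain-dict transition directly.
assert t[Key.ACTION] == [1.0]
```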
@@ -0,0 +1,145 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
@dataclass
class MapTensorToDeltaActionDictStep(ActionProcessorStep):
"""
Map a tensor to a delta action dictionary.
"""
use_gripper: bool = True
def action(self, action: Tensor) -> dict:
if action.dim() > 1:
action = action.squeeze(0)
# TODO (maractingi): add rotation
delta_action = {
f"{ACTION}.delta_x": action[0],
f"{ACTION}.delta_y": action[1],
f"{ACTION}.delta_z": action[2],
}
if self.use_gripper:
delta_action[f"{ACTION}.gripper"] = action[3]
return delta_action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[f"{ACTION}.delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
if self.use_gripper:
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
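The index-to-key mapping performed by `MapTensorToDeltaActionDictStep.action` can be sketched without torch (hypothetical helper name; plain sequences stand in for tensors): indices 0..2 become the x/y/z deltas, and index 3 the gripper command when enabled.

```python
def tensor_to_delta_dict(action, use_gripper: bool = True) -> dict:
    # action[0:3] -> delta_x/y/z; action[3] -> gripper (only if enabled).
    delta = {
        "action.delta_x": action[0],
        "action.delta_y": action[1],
        "action.delta_z": action[2],
    }
    if use_gripper:
        delta["action.gripper"] = action[3]
    return delta

assert tensor_to_delta_dict([0.1, 0.2, 0.3, 1.0])["action.gripper"] == 1.0
assert "action.gripper" not in tensor_to_delta_dict([0.1, 0.2, 0.3], use_gripper=False)
```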
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
@dataclass
class MapDeltaActionToRobotActionStep(ActionProcessorStep):
"""
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
for use with inverse kinematics processors.
Expected input ACTION keys:
{
"action.delta_x": float,
"action.delta_y": float,
"action.delta_z": float,
"action.gripper": float (optional),
}
Output ACTION keys:
{
"action.enabled": bool,
"action.target_x": float,
"action.target_y": float,
"action.target_z": float,
"action.target_wx": float,
"action.target_wy": float,
"action.target_wz": float,
"action.gripper": float,
}
"""
# Scale factors for delta movements
position_scale: float = 1.0
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
def action(self, action: dict) -> dict:
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
delta_x = action.pop(f"{ACTION}.delta_x", 0.0)
delta_y = action.pop(f"{ACTION}.delta_y", 0.0)
delta_z = action.pop(f"{ACTION}.delta_z", 0.0)
gripper = action.pop(f"{ACTION}.gripper", 1.0) # Default to "stay" (1.0)
# Determine if the teleoperator is actively providing input
# Consider enabled if any significant movement delta is detected
position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position
enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise
# Scale the deltas appropriately
scaled_delta_x = delta_x * self.position_scale
scaled_delta_y = delta_y * self.position_scale
scaled_delta_z = delta_z * self.position_scale
# For gamepad/keyboard, we don't have rotation input, so set to 0
# These could be extended in the future for more sophisticated teleoperators
target_wx = 0.0
target_wy = 0.0
target_wz = 0.0
# Update action with robot target format
action = {
f"{ACTION}.enabled": enabled,
f"{ACTION}.target_x": scaled_delta_x,
f"{ACTION}.target_y": scaled_delta_y,
f"{ACTION}.target_z": scaled_delta_z,
f"{ACTION}.target_wx": target_wx,
f"{ACTION}.target_wy": target_wy,
f"{ACTION}.target_wz": target_wz,
f"{ACTION}.gripper": float(gripper),
}
return action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transform features to match output format."""
features.pop(f"{ACTION}.delta_x", None)
features.pop(f"{ACTION}.delta_y", None)
features.pop(f"{ACTION}.delta_z", None)
features.pop(f"{ACTION}.gripper", None)
features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
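The "enabled" gating in `MapDeltaActionToRobotActionStep` hinges on the Euclidean norm of the position deltas exceeding `noise_threshold` (1e-3, roughly 1 mm). A stdlib-only sketch of just that decision (hypothetical function name):

```python
import math

def is_enabled(dx: float, dy: float, dz: float, noise_threshold: float = 1e-3) -> bool:
    # Teleop input counts as active only when the combined position delta
    # exceeds the noise threshold; smaller values are treated as sensor noise.
    return math.sqrt(dx * dx + dy * dy + dz * dz) > noise_threshold

assert not is_enabled(5e-4, 0.0, 0.0)  # below threshold: filtered as noise
assert is_enabled(0.0, 0.002, 0.0)     # above threshold: teleop active
```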
@@ -19,64 +19,116 @@ from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.utils.utils import get_safe_torch_device
from .core import EnvTransition, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register("device_processor")
@dataclass
class DeviceProcessor:
"""Processes transitions by moving tensors to the specified device.
class DeviceProcessorStep(ProcessorStep):
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
This processor ensures that all tensors in the transition are moved to the
specified device (CPU or GPU) before they are returned.
specified device (CPU or GPU) before they are returned. It can also convert
floating-point tensors to a specified dtype while preserving non-float types
(int, long, bool, etc.).
"""
device: torch.device = "cpu"
device: str = "cpu"
float_dtype: str | None = None
DTYPE_MAPPING = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"half": torch.float16,
"float": torch.float32,
"double": torch.float64,
}
def __post_init__(self):
self.device = get_safe_torch_device(self.device)
self.tensor_device: torch.device = get_safe_torch_device(self.device)
self.device = self.tensor_device.type # cuda might have changed to cuda:1
self.non_blocking = "cuda" in str(self.device)
# Validate and convert float_dtype string to torch dtype
if self.float_dtype is not None:
if self.float_dtype not in self.DTYPE_MAPPING:
raise ValueError(
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
)
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
else:
self._target_float_dtype = None
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""Process a tensor by moving to device and optionally converting float dtype.
If the tensor is already on a GPU and we're configured for a GPU, it preserves
that GPU placement (useful for multi-GPU training with Accelerate).
Otherwise, it moves to the configured device.
"""
# Determine target device
if tensor.is_cuda and self.tensor_device.type == "cuda":
# Both tensor and target are on GPU - preserve tensor's GPU placement
# This handles multi-GPU scenarios where Accelerate has already placed
# tensors on the correct GPU for each process
target_device = tensor.device
else:
# Either tensor is on CPU, or we're configured for CPU
# In both cases, use the configured device
target_device = self.tensor_device
# Only move if necessary
if tensor.device != target_device:
tensor = tensor.to(target_device, non_blocking=self.non_blocking)
# Convert float dtype if specified and tensor is floating point
if self._target_float_dtype is not None and tensor.is_floating_point():
tensor = tensor.to(dtype=self._target_float_dtype)
return tensor
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition
new_transition = transition.copy()
# Process observation tensors
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_observation = {
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
new_transition[TransitionKey.OBSERVATION] = new_observation
simple_tensor_keys = [
TransitionKey.ACTION,
TransitionKey.REWARD,
TransitionKey.DONE,
TransitionKey.TRUNCATED,
]
# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
dict_tensor_keys = [
TransitionKey.OBSERVATION,
TransitionKey.COMPLEMENTARY_DATA,
]
# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
# Process simple tensors
for key in simple_tensor_keys:
value = transition.get(key)
if isinstance(value, torch.Tensor):
new_transition[key] = self._process_tensor(value)
# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = truncated.to(
self.device, non_blocking=self.non_blocking
)
# Process dictionary-like tensors
for key in dict_tensor_keys:
data_dict = transition.get(key)
if data_dict is not None:
new_data_dict = {
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
for k, v in data_dict.items()
}
new_transition[key] = new_data_dict
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {"device": self.device}
return {"device": self.device, "float_dtype": self.float_dtype}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
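The key design choice in `_process_tensor` is the target-device decision: a tensor already resident on some CUDA device keeps its placement when the processor is also configured for CUDA (the multi-GPU / Accelerate case), and the configured device wins otherwise. A dependency-free sketch of that rule, with plain strings standing in for torch devices:

```python
def target_device(tensor_device: str, configured: str) -> str:
    # Both on CUDA: preserve the tensor's per-process GPU placement so
    # Accelerate-distributed tensors are not funneled onto one GPU.
    if tensor_device.startswith("cuda") and configured.startswith("cuda"):
        return tensor_device
    # Otherwise the configured device takes precedence.
    return configured

assert target_device("cuda:1", "cuda:0") == "cuda:1"  # multi-GPU: keep placement
assert target_device("cpu", "cuda:0") == "cuda:0"     # move CPU tensor to GPU
assert target_device("cuda:0", "cpu") == "cpu"        # configured CPU wins
```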
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import numpy as np
import torch
from lerobot.configs.types import PolicyFeature
from .converters import to_tensor
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register("torch2numpy_action_processor")
@dataclass
class Torch2NumpyActionProcessorStep(ActionProcessorStep):
"""Convert PyTorch tensor actions to NumPy arrays."""
squeeze_batch_dim: bool = True
def action(self, action: torch.Tensor) -> np.ndarray:
if not isinstance(action, torch.Tensor):
raise TypeError(
f"Expected torch.Tensor, got {type(action).__name__}. "
"Use an appropriate processor for non-tensor actions."
)
numpy_action = action.detach().cpu().numpy()
# Remove batch dimensions but preserve action dimensions
# Only squeeze if there's a batch dimension (first dim == 1)
if (
self.squeeze_batch_dim
and numpy_action.shape
and len(numpy_action.shape) > 1
and numpy_action.shape[0] == 1
):
numpy_action = numpy_action.squeeze(0)
return numpy_action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
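The batch-dimension handling above only removes a leading dimension of size 1, so genuine action dimensions survive. The shape logic in isolation (hypothetical function name, plain tuples instead of arrays):

```python
def squeeze_batch_dim(shape: tuple) -> tuple:
    # Drop the first dimension only when it is a singleton batch dim;
    # real batches and already-unbatched actions pass through unchanged.
    if shape and len(shape) > 1 and shape[0] == 1:
        return shape[1:]
    return shape

assert squeeze_batch_dim((1, 7)) == (7,)    # batched single action
assert squeeze_batch_dim((4, 7)) == (4, 7)  # real batch: untouched
assert squeeze_batch_dim((7,)) == (7,)      # already unbatched
```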
@ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass
class Numpy2TorchActionProcessorStep(ActionProcessorStep):
"""Convert NumPy array action to PyTorch tensor."""
def action(self, action: np.ndarray) -> torch.Tensor:
if not isinstance(action, np.ndarray):
raise TypeError(
f"Expected np.ndarray, got {type(action).__name__}. "
"Use an appropriate processor for non-array actions."
)
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
return torch_action
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -0,0 +1,382 @@
import math
import time
from dataclasses import dataclass
from typing import Any, Protocol, TypeVar, runtime_checkable
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PolicyFeature
from lerobot.constants import ACTION
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from .core import EnvTransition, TransitionKey
from .pipeline import (
ComplementaryDataProcessorStep,
InfoProcessorStep,
ObservationProcessorStep,
ProcessorStep,
ProcessorStepRegistry,
TruncatedProcessorStep,
)
GRIPPER_KEY = "gripper"
DISCRETE_PENALTY_KEY = "discrete_penalty"
TELEOP_ACTION_KEY = "teleop_action"
@runtime_checkable
class HasTeleopEvents(Protocol):
"""Minimal protocol for objects that provide teleoperation events.
This protocol only defines the additional get_teleop_events() method,
avoiding duplication of the entire Teleoperator interface.
"""
def get_teleop_events(self) -> dict[str, Any]:
"""Get extra control events from the teleoperator.
Returns:
Dictionary containing control events such as:
- is_intervention: bool - Whether human is currently intervening
- terminate_episode: bool - Whether to terminate the current episode
- success: bool - Whether the episode was successful
- rerecord_episode: bool - Whether to rerecord the episode
"""
...
# Type variable constrained to Teleoperator subclasses that also implement events
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
def _check_teleop_with_events(teleop: Teleoperator) -> None:
"""Runtime check that a teleoperator implements get_teleop_events."""
if not isinstance(teleop, HasTeleopEvents):
raise TypeError(
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
f"Compatible teleoperators: GamepadTeleop, KeyboardEndEffectorTeleop"
)
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@dataclass
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
"""Add teleoperator action to transition complementary data."""
teleop_device: Teleoperator
def complementary_data(self, complementary_data: dict) -> dict:
new_complementary_data = dict(complementary_data)
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
return new_complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
"""Add teleoperator control events to transition info.
This processor step extracts control events from teleoperators that support
event-based interaction (intervention detection, episode termination, etc.).
Works with any teleoperator that inherits from Teleoperator and implements the
get_teleop_events() method, including custom user-defined teleoperators.
Built-in compatible teleoperators:
- GamepadTeleop: Uses gamepad buttons for control events
- KeyboardEndEffectorTeleop: Uses keyboard keys for control events
"""
teleop_device: TeleopWithEvents
def __post_init__(self):
"""Validate that the teleoperator supports events."""
_check_teleop_with_events(self.teleop_device)
def info(self, info: dict) -> dict:
new_info = dict(info)
teleop_events = self.teleop_device.get_teleop_events()
new_info.update(teleop_events)
return new_info
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@ProcessorStepRegistry.register("image_crop_resize_processor")
@dataclass
class ImageCropResizeProcessorStep(ObservationProcessorStep):
"""Crop and resize image observations."""
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
def observation(self, observation: dict) -> dict:
if self.resize_size is None and not self.crop_params_dict:
return observation
new_observation = dict(observation)
# Process all image keys in the observation
for key in observation:
if "image" not in key:
continue
image = observation[key]
device = image.device
# NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu
if device.type == "mps":
image = image.cpu()
# Crop if crop params are provided for this key
if self.crop_params_dict is not None and key in self.crop_params_dict:
crop_params = self.crop_params_dict[key]
image = F.crop(image, *crop_params)
if self.resize_size is not None:
image = F.resize(image, self.resize_size)
image = image.clamp(0.0, 1.0)
new_observation[key] = image.to(device)
return new_observation
def get_config(self) -> dict[str, Any]:
return {
"crop_params_dict": self.crop_params_dict,
"resize_size": self.resize_size,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if self.resize_size is None:
return features
for key in features:
if "image" in key:
features[key] = PolicyFeature(type=features[key].type, shape=self.resize_size)
return features
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessorStep(TruncatedProcessorStep):
"""Track episode steps and enforce time limits."""
max_episode_steps: int
current_step: int = 0
def truncated(self, truncated):
self.current_step += 1
if self.current_step >= self.max_episode_steps:
truncated = True
# TODO (steven): missing an else truncated = False?
return truncated
def get_config(self) -> dict[str, Any]:
return {
"max_episode_steps": self.max_episode_steps,
}
def reset(self) -> None:
self.current_step = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
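The time-limit bookkeeping above is a simple step counter: each call increments it and forces truncation once the limit is reached, while `reset()` starts a fresh episode. A stripped-down, torch-free sketch:

```python
class TimeLimit:
    """Minimal sketch of TimeLimitProcessorStep's counter logic."""

    def __init__(self, max_episode_steps: int):
        self.max_episode_steps = max_episode_steps
        self.current_step = 0

    def truncated(self, truncated: bool = False) -> bool:
        # Count this step; truncate once the episode hits the limit.
        self.current_step += 1
        return truncated or self.current_step >= self.max_episode_steps

    def reset(self) -> None:
        self.current_step = 0

tl = TimeLimit(max_episode_steps=3)
assert [tl.truncated() for _ in range(3)] == [False, False, True]
tl.reset()
assert tl.truncated() is False  # counter restarted for the new episode
```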
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
"""Apply penalty for inappropriate gripper usage."""
penalty: float = -0.01
max_gripper_pos: float = 30.0
def complementary_data(self, complementary_data):
"""Calculate gripper penalty and add to complementary data."""
action = self.transition.get(TransitionKey.ACTION)
current_gripper_pos = complementary_data.get("raw_joint_positions", {}).get(GRIPPER_KEY)
if current_gripper_pos is None:
return complementary_data
gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
gripper_action_normalized = gripper_action / self.max_gripper_pos
# Normalize gripper state and action
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
# Calculate penalty boolean as in original
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
)
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Create new complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
return new_complementary_data
def get_config(self) -> dict[str, Any]:
return {
"penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos,
}
def reset(self) -> None:
"""Reset the processor state."""
self.last_gripper_state = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
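The penalty condition above fires when the commanded gripper action contradicts the current state: commanding open (> 0.5 after normalization) while the gripper is closed (< 0.5), or commanding close (< 0.5) while it is nearly fully open (> 0.75). The arithmetic in isolation (hypothetical function name):

```python
def gripper_penalty(state_pos: float, action_pos: float,
                    penalty: float = -0.01, max_gripper_pos: float = 30.0) -> float:
    # Normalize state and action to [0, 1] by the maximum gripper position.
    s = state_pos / max_gripper_pos
    a = action_pos / max_gripper_pos
    # Penalize open-while-closed and close-while-open commands.
    hit = (s < 0.5 and a > 0.5) or (s > 0.75 and a < 0.5)
    return penalty * int(hit)

assert gripper_penalty(state_pos=3.0, action_pos=27.0) == -0.01  # closed, asked to open
assert gripper_penalty(state_pos=27.0, action_pos=27.0) == 0.0   # consistent: no penalty
```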
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessorStep(ProcessorStep):
"""Handle human intervention actions and episode termination."""
use_gripper: bool = False
terminate_on_success: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is None:
return transition
# Get intervention signals from complementary data
info = transition.get(TransitionKey.INFO, {})
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
success = info.get(TeleopEvents.SUCCESS, False)
rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False)
new_transition = transition.copy()
# Override action if intervention is active
if is_intervention and teleop_action is not None:
if isinstance(teleop_action, dict):
# Convert teleop_action dict to tensor format
action_list = [
teleop_action.get(f"{ACTION}.delta_x", 0.0),
teleop_action.get(f"{ACTION}.delta_y", 0.0),
teleop_action.get(f"{ACTION}.delta_z", 0.0),
]
if self.use_gripper:
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
elif isinstance(teleop_action, np.ndarray):
action_list = teleop_action.tolist()
else:
action_list = teleop_action
teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
new_transition[TransitionKey.ACTION] = teleop_action_tensor
# Handle episode termination
new_transition[TransitionKey.DONE] = bool(terminate_episode) or (
self.terminate_on_success and success
)
new_transition[TransitionKey.REWARD] = float(success)
# Update info with intervention metadata
info = new_transition.get(TransitionKey.INFO, {})
info[TeleopEvents.IS_INTERVENTION] = is_intervention
info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode
info[TeleopEvents.SUCCESS] = success
new_transition[TransitionKey.INFO] = info
# Update complementary data with teleop action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"use_gripper": self.use_gripper,
"terminate_on_success": self.terminate_on_success,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
class RewardClassifierProcessorStep(ProcessorStep):
"""Apply reward classification to image observations."""
pretrained_path: str | None = None
device: str = "cpu"
success_threshold: float = 0.5
success_reward: float = 1.0
terminate_on_success: bool = True
reward_classifier: Any = None
def __post_init__(self):
"""Initialize the reward classifier after dataclass initialization."""
if self.pretrained_path is not None:
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
self.reward_classifier.to(self.device)
self.reward_classifier.eval()
def __call__(self, transition: EnvTransition) -> EnvTransition:
new_transition = transition.copy()
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is None or self.reward_classifier is None:
return new_transition
# Extract images from observation
images = {key: value for key, value in observation.items() if "image" in key}
if not images:
return new_transition
# Run reward classifier
start_time = time.perf_counter()
with torch.inference_mode():
success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold)
classifier_frequency = 1 / (time.perf_counter() - start_time)
# Calculate reward and termination
reward = new_transition.get(TransitionKey.REWARD, 0.0)
terminated = new_transition.get(TransitionKey.DONE, False)
if math.isclose(success, 1, abs_tol=1e-2):
reward = self.success_reward
if self.terminate_on_success:
terminated = True
# Update transition
new_transition[TransitionKey.REWARD] = reward
new_transition[TransitionKey.DONE] = terminated
# Update info with classifier frequency
info = new_transition.get(TransitionKey.INFO, {})
info["reward_classifier_frequency"] = classifier_frequency
new_transition[TransitionKey.INFO] = info
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"device": self.device,
"success_threshold": self.success_threshold,
"success_reward": self.success_reward,
"terminate_on_success": self.terminate_on_success,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
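The success handling above reduces to: when the classifier score is within `abs_tol` of 1, assign the success reward and, if configured, terminate the episode. The decision in isolation (hypothetical function name):

```python
import math

def apply_success(score: float, reward: float, done: bool,
                  success_reward: float = 1.0,
                  terminate_on_success: bool = True) -> tuple:
    # A score within 1e-2 of 1 counts as success, mirroring math.isclose above.
    if math.isclose(score, 1, abs_tol=1e-2):
        reward = success_reward
        if terminate_on_success:
            done = True
    return reward, done

assert apply_success(0.999, 0.0, False) == (1.0, True)  # success: reward + terminate
assert apply_success(0.4, 0.0, False) == (0.0, False)   # no success: untouched
```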
@@ -0,0 +1,109 @@
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_STATE
from lerobot.processor.pipeline import (
ObservationProcessorStep,
ProcessorStepRegistry,
)
from lerobot.robots import Robot
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessorStep(ObservationProcessorStep):
"""Add joint velocity information to observations."""
dt: float = 0.1
last_joint_positions: torch.Tensor | None = None
def observation(self, observation: dict) -> dict:
# Get current joint positions (assuming they're in observation.state)
current_positions = observation.get(OBS_STATE)
if current_positions is None:
# TODO(steven): if we get here, then the transform_features method will not hold
raise ValueError(f"{OBS_STATE} is not in observation")
# Initialize last joint positions if not already set
if self.last_joint_positions is None:
self.last_joint_positions = current_positions.clone()
joint_velocities = torch.zeros_like(current_positions)
else:
# Compute velocities
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
self.last_joint_positions = current_positions.clone()
# Extend observation with velocities
extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
# Create new observation dict
new_observation = dict(observation)
new_observation[OBS_STATE] = extended_state
return new_observation
def get_config(self) -> dict[str, Any]:
return {
"dt": self.dt,
}
def reset(self) -> None:
self.last_joint_positions = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if OBS_STATE in features:
original_feature = features[OBS_STATE]
# Double the shape to account for positions + velocities
new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]
features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
return features
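The finite-difference estimate above returns zero velocities on the first call (no history yet) and thereafter divides the position change by `dt`, then appends velocities to the state. A torch-free sketch with plain lists standing in for tensors:

```python
class VelocityEstimator:
    """Minimal sketch of JointVelocityProcessorStep's velocity logic."""

    def __init__(self, dt: float = 0.1):
        self.dt = dt
        self.last = None  # no position history yet

    def step(self, positions: list) -> list:
        if self.last is None:
            # First observation: velocities are unknown, report zeros.
            vel = [0.0] * len(positions)
        else:
            # Finite difference against the previous positions.
            vel = [(p - q) / self.dt for p, q in zip(positions, self.last)]
        self.last = list(positions)
        # Extended state: positions followed by velocities (shape doubles).
        return positions + vel

est = VelocityEstimator(dt=0.5)
assert est.step([1.0, 2.0]) == [1.0, 2.0, 0.0, 0.0]  # first call: zero velocities
assert est.step([2.0, 2.0]) == [2.0, 2.0, 2.0, 0.0]  # (2.0 - 1.0) / 0.5 = 2.0
```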
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessorStep(ObservationProcessorStep):
"""Add motor current information to observations."""
robot: Robot | None = None
def observation(self, observation: dict) -> dict:
# Get current values from robot state
if self.robot is None:
raise ValueError("Robot is not set")
present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
motor_currents = torch.tensor(
[present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
dtype=torch.float32,
).unsqueeze(0)
current_state = observation.get(OBS_STATE)
if current_state is None:
return observation
extended_state = torch.cat([current_state, motor_currents], dim=-1)
# Create new observation dict
new_observation = dict(observation)
new_observation[OBS_STATE] = extended_state
return new_observation
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
if OBS_STATE in features and self.robot is not None:
original_feature = features[OBS_STATE]
# Add motor current dimensions to the original state shape
num_motors = 0
if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
num_motors = len(self.robot.bus.motors) # type: ignore[attr-defined]
if num_motors > 0:
new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
return features
@@ -0,0 +1,503 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
This script:
1. Loads an existing pretrained policy model
2. Extracts normalization statistics from the model
3. Creates both preprocessor and postprocessor:
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
- Postprocessor: unnormalizes outputs (actions) for inference
4. Removes normalization layers from the model state_dict
5. Saves the new model and both processors
Usage:
python src/lerobot/processor/migrate_policy_normalization.py \
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
--policy-type act \
--push-to-hub
"""
import argparse
import importlib
import json
import os
from copy import deepcopy
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from .batch_processor import AddBatchDimensionProcessorStep
from .device_processor import DeviceProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from .pipeline import PolicyProcessorPipeline
from .rename_processor import RenameProcessorStep
# Policy type to class mapping
POLICY_CLASSES = {
"act": "lerobot.policies.act.modeling_act.ACTPolicy",
"diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
"pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
"pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
"smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
"tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
"vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
"sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
"classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
}
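These dotted paths are resolved at runtime via `importlib`, the same pattern the script applies further down (`rsplit(".", 1)` then `import_module` + `getattr`). A hedged sketch of that resolution step, demonstrated on a stdlib class so it runs anywhere:

```python
import importlib

# Resolve a dotted "module.Class" path like the entries in POLICY_CLASSES.
# The helper name is illustrative; the script inlines this logic instead.
def resolve(dotted_path):
    module_path, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

cls = resolve("collections.OrderedDict")
```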
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Extract normalization statistics from model state_dict."""
stats = {}
# Define patterns to match and their prefixes to remove
normalization_patterns = [
"normalize_inputs.buffer_",
"unnormalize_outputs.buffer_",
"normalize_targets.buffer_",
"normalize.", # Must come after normalize_* patterns
"unnormalize.", # Must come after unnormalize_* patterns
"input_normalizer.",
"output_normalizer.",
]
# Process each key in state_dict
for key, tensor in state_dict.items():
# Try each pattern
for pattern in normalization_patterns:
if key.startswith(pattern):
# Extract the remaining part after the pattern
remaining = key[len(pattern) :]
parts = remaining.split(".")
# Need at least feature name and stat type
if len(parts) >= 2:
# Last part is the stat type (mean, std, min, max, etc.)
stat_type = parts[-1]
# Everything else is the feature name
feature_name = ".".join(parts[:-1]).replace("_", ".")
# Add to stats
if feature_name not in stats:
stats[feature_name] = {}
stats[feature_name][stat_type] = tensor.clone()
# Only process the first matching pattern
break
return stats
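The key-parsing rule above can be summarized as: strip a known prefix, take the last dot-segment as the stat name, and turn the remaining underscores back into dots to recover the feature name. A minimal standalone sketch of that rule (function name is illustrative):

```python
# Sketch of the parsing performed per state_dict key in extract_normalization_stats.
def parse_stat_key(key, patterns):
    for pattern in patterns:
        if key.startswith(pattern):
            remaining = key[len(pattern):]
            parts = remaining.split(".")
            # Need at least a feature name and a stat type.
            if len(parts) >= 2:
                feature = ".".join(parts[:-1]).replace("_", ".")
                return feature, parts[-1]
            return None
    return None

patterns = ["normalize_inputs.buffer_", "unnormalize_outputs.buffer_"]
parsed = parse_stat_key("normalize_inputs.buffer_observation_state.mean", patterns)
```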
def detect_features_and_norm_modes(
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
"""Detect features and normalization modes from config and stats."""
features = {}
norm_modes = {}
# First, check if there's a normalization_mapping in the config
if "normalization_mapping" in config:
print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
# Extract normalization modes from config
for feature_name, mode_str in config["normalization_mapping"].items():
# Convert string to NormalizationMode enum
if mode_str == "mean_std":
mode = NormalizationMode.MEAN_STD
elif mode_str == "min_max":
mode = NormalizationMode.MIN_MAX
else:
print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
continue
# Determine feature type from feature name
if "image" in feature_name or "visual" in feature_name:
feature_type = FeatureType.VISUAL
elif "state" in feature_name:
feature_type = FeatureType.STATE
elif "action" in feature_name:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
norm_modes[feature_type] = mode
# Try to extract from config
if "features" in config:
for key, feature_config in config["features"].items():
shape = feature_config.get("shape", feature_config.get("dim"))
shape = (shape,) if isinstance(shape, int) else tuple(shape)
# Determine feature type
if "image" in key or "visual" in key:
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE # Default
features[key] = PolicyFeature(feature_type, shape)
# If no features in config, infer from stats
if not features:
for key, stat_dict in stats.items():
# Get shape from any stat tensor
tensor = next(iter(stat_dict.values()))
shape = tuple(tensor.shape)
# Determine feature type based on key
if "image" in key or "visual" in key or "pixels" in key:
feature_type = FeatureType.VISUAL
elif "state" in key or "joint" in key or "position" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
features[key] = PolicyFeature(feature_type, shape)
# If normalization modes weren't in config, determine based on available stats
if not norm_modes:
for key, stat_dict in stats.items():
if key in features:
if "mean" in stat_dict and "std" in stat_dict:
feature_type = features[key].type
if feature_type not in norm_modes:
norm_modes[feature_type] = NormalizationMode.MEAN_STD
elif "min" in stat_dict and "max" in stat_dict:
feature_type = features[key].type
if feature_type not in norm_modes:
norm_modes[feature_type] = NormalizationMode.MIN_MAX
# Default normalization modes if not detected
if FeatureType.VISUAL not in norm_modes:
norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD
if FeatureType.STATE not in norm_modes:
norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX
if FeatureType.ACTION not in norm_modes:
norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD
return features, norm_modes
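When the config carries no `normalization_mapping`, the mode is inferred from which statistics are present: mean/std implies MEAN_STD, min/max implies MIN_MAX. A scalar sketch of that fallback, with string tags standing in for the `NormalizationMode` enum:

```python
# Stats-driven fallback used when the config does not specify a mode.
def infer_mode(stat_dict):
    if "mean" in stat_dict and "std" in stat_dict:
        return "mean_std"
    if "min" in stat_dict and "max" in stat_dict:
        return "min_max"
    return None  # leave it to the hard-coded per-feature-type defaults

mode = infer_mode({"mean": [0.0], "std": [1.0]})
```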
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Remove normalization layers from state_dict."""
new_state_dict = {}
# Patterns to remove
remove_patterns = [
"normalize_inputs.",
"unnormalize_outputs.",
"normalize_targets.", # Added pattern for target normalization
"normalize.",
"unnormalize.",
"input_normalizer.",
"output_normalizer.",
"normalizer.",
]
for key, tensor in state_dict.items():
should_remove = any(pattern in key for pattern in remove_patterns)
if not should_remove:
new_state_dict[key] = tensor
return new_state_dict
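The filtering rule is a plain substring check: any key containing one of the normalization prefixes is dropped, everything else is kept verbatim. A compact sketch (helper name and pattern subset are illustrative):

```python
# Sketch of remove_normalization_layers' filter over a state_dict.
def strip_norm_keys(state_dict, patterns=("normalize_inputs.", "unnormalize_outputs.")):
    return {
        k: v for k, v in state_dict.items()
        if not any(p in k for p in patterns)
    }

kept = strip_norm_keys({
    "encoder.weight": 1,
    "normalize_inputs.buffer_observation_state.mean": 2,
})
```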
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert features from old format to PolicyFeature objects."""
converted_features = {}
for key, feature_dict in features_dict.items():
# Determine feature type based on key
if "image" in key or "visual" in key:
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
# Get shape from feature dict
shape = feature_dict.get("shape", feature_dict.get("dim"))
shape = (shape,) if isinstance(shape, int) else tuple(shape)
converted_features[key] = PolicyFeature(feature_type, shape)
return converted_features
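The shape coercion used here (and in `detect_features_and_norm_modes`) normalizes both legacy formats: an int `dim` becomes a 1-tuple, any sequence becomes a plain tuple, so downstream code can rely on tuple shapes. A minimal sketch:

```python
# Shape coercion mirroring convert_features_to_policy_features.
def coerce_shape(feature_dict):
    # Old configs store either "shape" (sequence) or "dim" (int).
    shape = feature_dict.get("shape", feature_dict.get("dim"))
    return (shape,) if isinstance(shape, int) else tuple(shape)
```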
def load_model_from_hub(
repo_id: str, revision: str | None = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""Load model state_dict and config from hub."""
# Download files
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
# Load state_dict
state_dict = load_safetensors(safetensors_path)
# Load config
with open(config_path) as f:
config = json.load(f)
with open(train_config_path) as f:
train_config = json.load(f)
return state_dict, config, train_config
def main():
parser = argparse.ArgumentParser(
description="Migrate policy models with normalization layers to new pipeline system"
)
parser.add_argument(
"--pretrained-path",
type=str,
required=True,
help="Path to pretrained model (hub repo or local directory)",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Output directory for migrated model (default: same as pretrained-path)",
)
parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub")
parser.add_argument(
"--hub-repo-id",
type=str,
default=None,
help="Hub repository ID for pushing (default: same as pretrained-path)",
)
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
args = parser.parse_args()
# Load model and config
print(f"Loading model from {args.pretrained_path}...")
if os.path.isdir(args.pretrained_path):
# Local directory
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
with open(os.path.join(args.pretrained_path, "config.json")) as f:
config = json.load(f)
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
train_config = json.load(f)
else:
# Hub repository
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
# Extract normalization statistics
print("Extracting normalization statistics...")
stats = extract_normalization_stats(state_dict)
print(f"Found normalization statistics for: {list(stats.keys())}")
# Detect input features and normalization modes
print("Detecting features and normalization modes...")
features, norm_map = detect_features_and_norm_modes(config, stats)
print(f"Detected features: {list(features.keys())}")
print(f"Normalization modes: {norm_map}")
# Remove normalization layers from state_dict
print("Removing normalization layers from model...")
new_state_dict = remove_normalization_layers(state_dict)
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
if removed_keys:
print(f"Removed {len(removed_keys)} normalization layer keys")
# Determine output path
if args.output_dir:
output_dir = Path(args.output_dir)
else:
if os.path.isdir(args.pretrained_path):
output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated"
else:
output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated")
output_dir.mkdir(parents=True, exist_ok=True)
# Clean up config - remove normalization_mapping field
cleaned_config = dict(config)
if "normalization_mapping" in cleaned_config:
print("Removing 'normalization_mapping' field from config")
del cleaned_config["normalization_mapping"]
policy_type = deepcopy(cleaned_config["type"])
del cleaned_config["type"]
# Instantiate the policy model with cleaned config and load the cleaned state dict
print(f"Instantiating {policy_type} policy model...")
policy_class_path = POLICY_CLASSES[policy_type]
module_path, class_name = policy_class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
policy_class = getattr(module, class_name)
# Create config class instance
config_module_path = module_path.replace("modeling", "configuration")
config_module = importlib.import_module(config_module_path)
# Handle special cases for config class names
config_class_names = {
"act": "ACTConfig",
"diffusion": "DiffusionConfig",
"pi0": "PI0Config",
"pi0fast": "PI0FASTConfig",
"smolvla": "SmolVLAConfig",
"tdmpc": "TDMPCConfig",
"vqbet": "VQBeTConfig",
"sac": "SACConfig",
"classifier": "ClassifierConfig",
}
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
config_class = getattr(config_module, config_class_name)
# Convert input_features and output_features to PolicyFeature objects - these are mandatory
if "input_features" not in cleaned_config:
raise ValueError("Missing mandatory 'input_features' in config")
if "output_features" not in cleaned_config:
raise ValueError("Missing mandatory 'output_features' in config")
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
# Create config instance from cleaned config dict
policy_config = config_class(**cleaned_config)
# Create policy instance - some policies expect dataset_stats
policy = policy_class(policy_config)
# Load the cleaned state dict
policy.load_state_dict(new_state_dict, strict=True)
print("Successfully loaded cleaned state dict into policy model")
# Now create preprocessor and postprocessor with cleaned_config available
print("Creating preprocessor and postprocessor...")
# The pattern from existing processor factories:
# - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
# - Postprocessor has one UnnormalizerProcessorStep for output_features only
# Get features from cleaned_config (now they're PolicyFeature objects)
input_features = cleaned_config.get("input_features", {})
output_features = cleaned_config.get("output_features", {})
# Create preprocessor with two normalizers (following the pattern from processor factories)
preprocessor_steps = [
RenameProcessorStep(rename_map={}),
NormalizerProcessorStep(
features={**input_features, **output_features},
norm_map=norm_map,
stats=stats,
),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=policy_config.device),
]
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="policy_preprocessor")
# Create postprocessor with unnormalizer for outputs only
postprocessor_steps = [
DeviceProcessorStep(device="cpu"),
UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
]
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="policy_postprocessor")
# Determine hub repo ID if pushing to hub
if args.push_to_hub:
if args.hub_repo_id:
hub_repo_id = args.hub_repo_id
else:
if not os.path.isdir(args.pretrained_path):
# Use same repo with "_migrated" suffix
hub_repo_id = f"{args.pretrained_path}_migrated"
else:
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
else:
hub_repo_id = None
# Save preprocessor and postprocessor to root directory
print(f"Saving preprocessor to {output_dir}...")
preprocessor.save_pretrained(output_dir)
if args.push_to_hub:
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
print(f"Saving postprocessor to {output_dir}...")
postprocessor.save_pretrained(output_dir)
if args.push_to_hub:
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
# Save model using the policy's save_pretrained method
print(f"Saving model to {output_dir}...")
policy.save_pretrained(
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
)
# Generate and save model card
print("Generating model card...")
# Get metadata from original config
dataset_repo_id = train_config.get("repo_id", "unknown")
license = config.get("license", "apache-2.0")
tags = config.get("tags") or []
tags = sorted(set(tags) | {"robotics", "lerobot", policy_type})
# Generate model card
card = policy.generate_model_card(
dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
)
# Save model card locally
card.save(str(output_dir / "README.md"))
print(f"Model card saved to {output_dir / 'README.md'}")
# Push model card to hub if requested
if args.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
api.upload_file(
path_or_fileobj=str(output_dir / "README.md"),
path_in_repo="README.md",
repo_id=hub_repo_id,
repo_type="model",
commit_message="Add model card for migrated model",
)
print("Model card pushed to hub")
print("\nMigration complete!")
print(f"Migrated model saved to: {output_dir}")
if args.push_to_hub:
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
if __name__ == "__main__":
main()
@@ -1,179 +1,84 @@
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
"""Convert numpy arrays and other types to torch tensors."""
tensor_stats: dict[str, dict[str, Tensor]] = {}
for key, sub in stats.items():
tensor_stats[key] = {}
for stat_name, value in sub.items():
if isinstance(value, np.ndarray):
tensor_val = torch.from_numpy(value.astype(np.float32))
elif isinstance(value, torch.Tensor):
tensor_val = value.to(dtype=torch.float32)
elif isinstance(value, (int, float, list, tuple)):
tensor_val = torch.tensor(value, dtype=torch.float32)
else:
raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
tensor_stats[key][stat_name] = tensor_val
return tensor_stats
from .converters import to_tensor
from .core import EnvTransition, TransitionKey
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessor:
"""Normalizes observations and actions in a single processor step.
This processor handles normalization of both observation and action tensors
using either mean/std normalization or min/max scaling to a [-1, 1] range.
For each tensor key in the stats dictionary, the processor will:
- Use mean/std normalization if those statistics are provided: (x - mean) / std
- Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1
The processor can be configured to normalize only specific keys by setting
the normalize_keys parameter.
"""
class _NormalizationMixin:
"""
A mixin class providing core functionality for normalization and unnormalization.
This class manages normalization statistics, their conversion to tensors, device placement,
and the application of normalization transformations. It is designed to be inherited by
concrete ProcessorStep implementations.
"""
# Features and normalisation map are mandatory to match the design of normalize.py
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
# Pre-computed statistics coming from dataset.meta.stats for instance.
stats: dict[str, dict[str, Any]] | None = None
# Explicit subset of keys to normalise. If ``None`` every key (except
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
# membership checks O(1).
normalize_keys: set[str] | None = None
device: torch.device | str | None = None
eps: float = 1e-8
normalize_observation_keys: set[str] | None = None
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@classmethod
def from_lerobot_dataset(
cls,
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
normalize_keys: set[str] | None = None,
eps: float = 1e-8,
) -> NormalizerProcessor:
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
The features and norm_map parameters are mandatory to match the design
pattern used in normalize.py.
"""
return cls(
features=features,
norm_map=norm_map,
stats=dataset.meta.stats,
normalize_keys=normalize_keys,
eps=eps,
)
def __post_init__(self):
# Handle deserialization from JSON config
if self.features and isinstance(list(self.features.values())[0], dict):
# Features came from JSON - need to reconstruct PolicyFeature objects
reconstructed_features = {}
for key, ft_dict in self.features.items():
reconstructed_features[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed_features
# Robust JSON deserialization handling (guard empty maps)
if self.features:
first_val = next(iter(self.features.values()))
if isinstance(first_val, dict):
reconstructed = {}
for key, ft_dict in self.features.items():
reconstructed[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
# norm_map came from JSON - need to reconstruct enum keys and values
reconstructed_norm_map = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed_norm_map
if self.norm_map:
# if keys are strings (JSON), rebuild enum map
if all(isinstance(k, str) for k in self.norm_map.keys()):
reconstructed = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed
# Convert statistics once so we avoid repeated numpy→Tensor conversions
# during runtime.
# Convert stats to tensors and move to the target device once during initialization.
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
self._tensor_stats = to_tensor(self.stats, device=self.device)
# Ensure *normalize_keys* is a set for fast look-ups and compare by
# value later when returning the configuration.
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
self.normalize_keys = set(self.normalize_keys)
def to(self, device: torch.device | str) -> _NormalizationMixin:
"""Moves the processor's normalization stats to the specified device and returns self."""
self.device = device
self._tensor_stats = to_tensor(self.stats, device=self.device)
return self
def state_dict(self) -> dict[str, Tensor]:
flat: dict[str, Tensor] = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor.cpu()  # Always save to CPU
return flat
def load_state_dict(self, state: dict[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
# Load to the processor's configured device.
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
)
def _normalize_obs(self, observation):
if observation is None:
return None
# Decide which keys should be normalised for this call.
if self.normalize_keys is not None:
keys_to_norm = self.normalize_keys
else:
# Use feature map to skip action keys.
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
processed = dict(observation)
for key in keys_to_norm:
if key not in processed or key not in self._tensor_stats:
continue
orig_val = processed[key]
tensor = (
orig_val.to(dtype=torch.float32)
if isinstance(orig_val, torch.Tensor)
else torch.as_tensor(orig_val, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
return processed
def _normalize_action(self, action):
if action is None or "action" not in self._tensor_stats:
return action
tensor = (
action.to(dtype=torch.float32)
if isinstance(action, torch.Tensor)
else torch.as_tensor(action, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return (tensor - mean) / (std + self.eps)
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._normalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with normalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
return new_transition
def get_config(self) -> dict[str, Any]:
config = {
@@ -183,45 +88,87 @@ class NormalizerProcessor:
},
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
}
if self.normalize_keys is not None:
# Serialise as a list for YAML / JSON friendliness
config["normalize_keys"] = sorted(self.normalize_keys)
if self.normalize_observation_keys is not None:
config["normalize_observation_keys"] = sorted(self.normalize_observation_keys)
return config
def state_dict(self) -> dict[str, Tensor]:
flat = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor
return flat
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
new_observation = dict(observation)
for key, feature in self.features.items():
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
continue
if feature.type != FeatureType.ACTION and key in new_observation:
tensor = torch.as_tensor(new_observation[key], dtype=torch.float32)
new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
return new_observation
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
tensor = torch.as_tensor(action, dtype=torch.float32)
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
return processed_action
def reset(self):
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
def _apply_transform(
self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
) -> Tensor:
"""Core logic to apply normalization or unnormalization."""
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
return tensor
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# Ensure input tensor is on the same device as the stats.
if self.device and tensor.device != self.device:
tensor = tensor.to(self.device)
# For Accelerate compatibility: move stats to match input tensor device
input_device = tensor.device
stats = self._tensor_stats[key]
tensor = tensor.to(dtype=torch.float32)
# Move stats to input device if needed
stats_device = next(iter(stats.values())).device
if stats_device != input_device:
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
# Avoid division by zero by adding a small epsilon.
denom = std + self.eps
if inverse:
return tensor * std + mean
return (tensor - mean) / denom
if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
denom = max_val - min_val
# When min_val == max_val, substitute the denominator with a small epsilon
# to prevent division by zero. This consistently maps an input equal to
# min_val to -1, ensuring a stable transformation.
denom = torch.where(
denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom
)
if inverse:
# Map from [-1, 1] back to [min, max]
return (tensor + 1) / 2 * denom + min_val
# Map from [min, max] to [-1, 1]
return 2 * (tensor - min_val) / denom - 1
# If necessary stats are missing, return input unchanged.
return tensor
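The min/max branch above, including the epsilon guard for degenerate features where min equals max, can be checked with plain scalars. A sketch (function name and values are illustrative; torch tensors are replaced by floats):

```python
# Scalar version of the MIN_MAX transform and its inverse, with the eps guard
# that maps a collapsed-range input (min == max) to -1 instead of dividing by zero.
def min_max(x, lo, hi, eps=1e-8, inverse=False):
    denom = (hi - lo) or eps  # substitute eps when the range collapses
    if inverse:
        # Map from [-1, 1] back to [lo, hi].
        return (x + 1) / 2 * denom + lo
    # Map from [lo, hi] to [-1, 1].
    return 2 * (x - lo) / denom - 1

y = min_max(0.75, 0.0, 1.0)              # normalize into [-1, 1]
x = min_max(y, 0.0, 1.0, inverse=True)   # round-trip back to the original value
```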
@dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor")
class UnnormalizerProcessor:
"""Inverse normalisation for observations and actions.
Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
transform.
"""
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
stats: dict[str, dict[str, Any]] | None = None
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@dataclass
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
"""
A processor that applies normalization to observations and actions in a transition.
This class directly implements the normalization logic for both observation and action
components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
initialization.
"""
@classmethod
def from_lerobot_dataset(
@@ -229,103 +176,97 @@ class UnnormalizerProcessor:
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
) -> UnnormalizerProcessor:
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
def __post_init__(self):
# Handle deserialization from JSON config
if self.features and isinstance(list(self.features.values())[0], dict):
# Features came from JSON - need to reconstruct PolicyFeature objects
reconstructed_features = {}
for key, ft_dict in self.features.items():
reconstructed_features[key] = PolicyFeature(
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
)
self.features = reconstructed_features
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
# norm_map came from JSON - need to reconstruct enum keys and values
reconstructed_norm_map = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed_norm_map
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
def _unnormalize_obs(self, observation):
if observation is None:
return None
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
processed = dict(observation)
for key in keys:
if key not in processed or key not in self._tensor_stats:
continue
orig_val = processed[key]
tensor = (
orig_val.to(dtype=torch.float32)
if isinstance(orig_val, torch.Tensor)
else torch.as_tensor(orig_val, dtype=torch.float32)
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
elif "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
return processed
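The min/max branch above maps statistics-scaled values back from the [-1, 1] range produced at normalization time. A minimal standalone sketch of the forward mapping and its inverse (plain Python with hypothetical values, not the lerobot tensors):

```python
def normalize_minmax(x: float, min_val: float, max_val: float) -> float:
    # Forward mapping used by the normalizer: scale to [-1, 1].
    return 2.0 * (x - min_val) / (max_val - min_val) - 1.0


def unnormalize_minmax(x: float, min_val: float, max_val: float) -> float:
    # Inverse mapping as above: (x + 1) / 2 * (max - min) + min.
    return (x + 1.0) / 2.0 * (max_val - min_val) + min_val


value = 0.3
roundtrip = unnormalize_minmax(normalize_minmax(value, 0.0, 2.0), 0.0, 2.0)
print(abs(roundtrip - value) < 1e-9)  # True — the two mappings are exact inverses
```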
def _unnormalize_action(self, action):
if action is None or "action" not in self._tensor_stats:
return action
tensor = (
action.to(dtype=torch.float32)
if isinstance(action, torch.Tensor)
else torch.as_tensor(action, dtype=torch.float32)
)
*,
normalize_observation_keys: set[str] | None = None,
eps: float = 1e-8,
device: torch.device | str | None = None,
) -> NormalizerProcessorStep:
return cls(
features=features,
norm_map=norm_map,
stats=dataset.meta.stats,
normalize_observation_keys=normalize_observation_keys,
eps=eps,
device=device,
)
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
return tensor * std + mean
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
return (tensor + 1) / 2 * (max_val - min_val) + min_val
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with unnormalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Handle observation normalization.
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(
observation, inverse=False
)
# Handle action normalization.
action = new_transition.get(TransitionKey.ACTION)
if action is not None:
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)
return new_transition
def get_config(self) -> dict[str, Any]:
return {
"features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
},
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
}
def state_dict(self) -> dict[str, Tensor]:
flat = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
def reset(self):
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor")
class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
"""
A processor that applies unnormalization (the inverse of normalization) to
observations and actions in a transition.
This is typically used to transform actions from a normalized policy output back into
the original scale for execution in an environment.
"""
@classmethod
def from_lerobot_dataset(
cls,
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
device: torch.device | str | None = None,
) -> UnnormalizerProcessorStep:
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
def __call__(self, transition: EnvTransition) -> EnvTransition:
new_transition = transition.copy()
# Handle observation unnormalization.
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True)
# Handle action unnormalization.
action = new_transition.get(TransitionKey.ACTION)
if action is not None:
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)
return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
def hotswap_stats(
policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]]
) -> PolicyProcessorPipeline:
"""
Replaces normalization statistics in a PolicyProcessor pipeline.
This function creates a deep copy of the provided `PolicyProcessorPipeline` and updates the
statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` steps within it.
It's useful for adapting a trained policy to a new environment or dataset with
different data distributions.
"""
rp = deepcopy(policy_processor)
for step in rp.steps:
if isinstance(step, _NormalizationMixin):
step.stats = stats
# Re-initialize tensor_stats on the correct device.
step._tensor_stats = to_tensor(stats, device=step.device)
return rp
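`hotswap_stats` follows a deep-copy-and-mutate pattern: the caller's pipeline is left untouched and only the returned copy carries the new statistics. A stripped-down sketch of that pattern (hypothetical `_Pipeline`/`_StatsStep` stand-ins, not the real lerobot classes):

```python
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class _StatsStep:  # stand-in for a step carrying normalization stats
    stats: dict = field(default_factory=dict)


@dataclass
class _Pipeline:  # stand-in for PolicyProcessorPipeline
    steps: list


def hotswap(pipeline: _Pipeline, stats: dict) -> _Pipeline:
    new_pipeline = deepcopy(pipeline)  # never mutate the caller's pipeline
    for step in new_pipeline.steps:
        if isinstance(step, _StatsStep):
            step.stats = stats
    return new_pipeline


original = _Pipeline(steps=[_StatsStep(stats={"action": {"mean": 0.0}})])
swapped = hotswap(original, {"action": {"mean": 1.0}})
print(original.steps[0].stats["action"]["mean"])  # 0.0 — original untouched
```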
@@ -22,12 +22,13 @@ from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessor(ObservationProcessor):
class VanillaObservationProcessorStep(ObservationProcessorStep):
"""
Processes environment observations into the LeRobot format by handling both images and states.
@@ -106,9 +107,8 @@ class VanillaObservationProcessor(ObservationProcessor):
def observation(self, observation):
return self._process_observation(observation)
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms feature keys to a standardized contract.
This method handles several renaming patterns:
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
@@ -13,19 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import (
ObservationProcessor,
ProcessorStepRegistry,
)
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessor(ObservationProcessor):
class RenameProcessorStep(ObservationProcessorStep):
"""Rename processor that renames keys in the observation."""
rename_map: dict[str, str] = field(default_factory=dict)
@@ -43,9 +42,20 @@ class RenameProcessor(ObservationProcessor):
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
- Each key in the observation that appears in `rename_map` is renamed to its value.
- Keys not in `rename_map` remain unchanged.
"""
return {self.rename_map.get(k, k): v for k, v in features.items()}
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
"""Rename keys in the stats dictionary according to rename_map (defensive copy)."""
if not stats:
return {}
renamed: dict[str, dict[str, Any]] = {}
for old_key, sub_stats in stats.items():
new_key = rename_map.get(old_key, old_key)
renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {}
return renamed
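The defensive copy in `rename_stats` matters: mutating the renamed stats must not leak back into the originals. A self-contained check of the behavior, reproducing the function as defined above:

```python
from copy import deepcopy
from typing import Any


def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
    # Same logic as above: rename keys, deep-copy the per-key sub-stats.
    if not stats:
        return {}
    renamed: dict[str, dict[str, Any]] = {}
    for old_key, sub_stats in stats.items():
        new_key = rename_map.get(old_key, old_key)
        renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {}
    return renamed


stats = {"pixels": {"mean": [0.5]}, "state": {"mean": [0.0]}}
renamed = rename_stats(stats, {"pixels": "observation.image"})
renamed["observation.image"]["mean"][0] = 0.9
print(sorted(renamed))             # ['observation.image', 'state']
print(stats["pixels"]["mean"][0])  # 0.5 — the deep copy shields the original
```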
@@ -0,0 +1,246 @@
"""
Tokenizer processor for handling text tokenization in robot transitions.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.utils.import_utils import _transformers_available
from .core import EnvTransition, TransitionKey
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
else:
AutoTokenizer = None
@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessorStep(ObservationProcessorStep):
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
This processor handles tokenization of task strings found in the complementary_data
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
to the observation data for model processing while preserving the original task string.
The processor supports both single strings and lists of strings as task inputs.
Args:
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
tokenizer: A tokenizer object (e.g., from transformers library) that implements
the __call__ method. If provided, tokenizer_name is ignored. This parameter
is not serialized and must be provided via overrides when loading.
max_length: Maximum sequence length for tokenization. Defaults to 512.
task_key: Key in complementary_data containing the task text. Defaults to "task".
padding: Padding strategy for tokenization. Defaults to "max_length".
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
Examples:
Using tokenizer name (auto-loaded):
```python
processor = TokenizerProcessorStep(tokenizer_name="bert-base-uncased", max_length=128)
```
Using custom tokenizer object:
```python
from transformers import AutoTokenizer
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = TokenizerProcessorStep(tokenizer=custom_tokenizer, max_length=128)
```
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None  # Typed as Any because transformers is not a core dependency
max_length: int = 512
task_key: str = "task"
padding_side: str = "right"
padding: str = "max_length"
truncation: bool = True
# Internal tokenizer instance (not serialized)
input_tokenizer: Any = field(default=None, init=False, repr=False)
def __post_init__(self):
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessorStep."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self.input_tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if AutoTokenizer is None:
raise ImportError("AutoTokenizer is not available")
self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
def get_task(self, transition: EnvTransition) -> list[str] | None:
"""Extract and normalize task from complementary data.
Args:
transition: Input transition containing complementary_data.
Returns:
List of task strings if task is present, None otherwise.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
if self.task_key not in complementary_data:
return None
task = complementary_data[self.task_key]
if task is None:
return None
# Convert to list of strings
if isinstance(task, str):
return [task]
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
return task
return None
def observation(self, observation):
"""Process the transition by tokenizing the task text.
Args:
transition: Input transition containing complementary_data with task text.
Returns:
Modified transition with tokenized task added to observation.
Raises:
ValueError: If tokenizer initialization failed.
"""
task = self.get_task(self.transition)
if task is None:
return observation
# Tokenize the task (creates CPU tensors)
tokenized_prompt = self._tokenize_text(task)
# Detect device from existing tensors in the transition
target_device = self._detect_device(self.transition)
# Move tokenized tensors to match the device of other data
if target_device is not None:
tokenized_prompt = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_prompt.items()
}
# Get or create observation dict
new_observation = dict(observation)
# Add tokenized data to observation
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
return new_observation
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
"""Detect device from existing tensors in the transition.
This allows the tokenized tensors to match the device of other data,
which is especially important for multi-GPU training with Accelerate.
Args:
transition: The transition to search for existing tensors.
Returns:
The device of the first tensor found, or None if no tensors exist.
"""
# Check observation tensors first (most likely to exist)
observation = transition.get(TransitionKey.OBSERVATION)
if observation:
for value in observation.values():
if isinstance(value, torch.Tensor):
return value.device
# Check action tensor
action = transition.get(TransitionKey.ACTION)
if isinstance(action, torch.Tensor):
return action.device
return None # No tensors found, keep on CPU
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
"""Tokenize text using the configured tokenizer.
Args:
text: Text string or list of strings to tokenize.
Returns:
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
"""
return self.input_tokenizer(
text,
max_length=self.max_length,
truncation=self.truncation,
padding=self.padding,
padding_side=self.padding_side,
return_tensors="pt",
)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.
Note: Only tokenizer_name is saved, not the tokenizer object itself.
When loading, provide the tokenizer via overrides if needed.
"""
config = {
"max_length": self.max_length,
"task_key": self.task_key,
"padding_side": self.padding_side,
"padding": self.padding,
"truncation": self.truncation,
}
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
# TODO(steven): Consider saving the name of the _tokenizer if it was loaded
if self.tokenizer_name is not None and self.tokenizer is None:
config["tokenizer_name"] = self.tokenizer_name
return config
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the feature contract.
Args:
features: Input feature dictionary.
Returns:
Updated feature dictionary with tokenized task features added.
"""
# Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask
if OBS_LANGUAGE_TOKENS not in features:
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if OBS_LANGUAGE_ATTENTION_MASK not in features:
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
return features
@@ -59,7 +59,7 @@ lerobot-record \
import logging
import time
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from pathlib import Path
from pprint import pformat
@@ -72,10 +72,23 @@ from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import (
IdentityProcessorStep,
PolicyProcessorPipeline,
RobotProcessorPipeline,
TransitionKey,
)
from lerobot.processor.converters import (
action_to_transition,
observation_to_transition,
transition_to_dataset_frame,
transition_to_robot_action,
)
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -149,6 +162,8 @@ class DatasetRecordConfig:
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self):
if self.single_task is None:
@@ -187,6 +202,36 @@ class RecordConfig:
return ["policy"]
""" --------------- record_loop() data flow --------------------------
[ Robot ]
V
[ robot.get_observation() ] ---> raw_obs
V
[ robot_observation_processor ] ---> obs_transition
V
.-----( ACTION LOGIC )------------------.
V V
[ From Teleoperator ] [ From Policy ]
| |
| [teleop.get_action] -> raw_action | [predict_action]
| | | |
| V | V
| [teleop_action_processor] | |
| | | |
'---> teleop_transition '---> policy_transition
| |
'-------------------------.-------------'
V
[ robot_action_processor ] --> robot_action_to_send
V
[ robot.send_action() ] -- (Robot Executes)
V
( Transitions are merged & added to Dataset )
V
( Rerun Log / Loop Wait )
"""
@safe_stop_image_writer
def record_loop(
robot: Robot,
@@ -195,15 +240,32 @@ def record_loop(
dataset: LeRobotDataset | None = None,
teleop: Teleoperator | list[Teleoperator] | None = None,
policy: PreTrainedPolicy | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
control_time_s: int | None = None,
teleop_action_processor: RobotProcessorPipeline | None = None, # runs after teleop
robot_action_processor: RobotProcessorPipeline | None = None, # runs before robot
robot_observation_processor: RobotProcessorPipeline | None = None, # runs after robot
single_task: str | None = None,
display_data: bool = False,
):
teleop_action_processor = teleop_action_processor or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=lambda tr: tr
)
robot_action_processor = robot_action_processor or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], to_transition=lambda tr: tr, to_output=transition_to_robot_action
)
robot_observation_processor = robot_observation_processor or RobotProcessorPipeline(
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=lambda tr: tr,
)
if dataset is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
teleop_arm = teleop_keyboard = None
if isinstance(teleop, list):
if isinstance(teleop, list): # For LeKiwi
teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
teleop_arm = next(
(
@@ -219,9 +281,20 @@ def record_loop(
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
)
# if policy is given it needs cleaning up
if policy is not None:
# Reset policy and processor if they are provided
if policy is not None and preprocessor is not None and postprocessor is not None:
policy.reset()
preprocessor.reset()
postprocessor.reset()
# Reset custom pipelines
teleop_action_processor.reset()
robot_action_processor.reset()
robot_observation_processor.reset()
policy_transition = None
teleop_transition = None
obs_transition = None
timestamp = 0
start_episode_t = time.perf_counter()
@@ -232,51 +305,88 @@ def record_loop(
events["exit_early"] = False
break
observation = robot.get_observation()
# Get robot observation
obs = robot.get_observation()
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
# Applies a pipeline to the raw robot observation, default is IdentityProcessor
obs_transition = robot_observation_processor(obs)
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
if dataset is not None:
observation_frame = transition_to_dataset_frame(
obs_transition, dataset.features
) # Convert the observation to the dataset format
if policy is not None:
action_values = predict_action(
observation_frame,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
elif policy is None and isinstance(teleop, Teleoperator):
action = teleop.get_action()
elif policy is None and isinstance(teleop, list):
# TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
action_names = dataset.features["action"]["names"]
policy_action = {f"action.{name}": float(action_values[i]) for i, name in enumerate(action_names)}
policy_transition = {
TransitionKey.ACTION: policy_action,
TransitionKey.COMPLEMENTARY_DATA: {},
}
elif isinstance(teleop, Teleoperator):
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
teleop_transition = teleop_action_processor(act)
elif isinstance(teleop, list):
arm_action = teleop_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
keyboard_action = teleop_keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_action)
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
teleop_transition = teleop_action_processor(act)
else:
logging.info(
"No policy or teleoperator provided, skipping action generation."
"No policy or teleoperator provided, skipping action generation. "
"This is likely to happen when resetting the environment without a teleop device."
"The robot won't be at its rest position at the start of the next episode."
)
continue
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
sent_action = robot.send_action(action)
# Applies a pipeline to the action, default is IdentityProcessor
# IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
if policy is not None and policy_transition is not None:
robot_action_to_send = robot_action_processor(policy_transition)
else:
robot_action_to_send = robot_action_processor(teleop_transition)
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_ = robot.send_action(robot_action_to_send)
# Write to dataset
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame}
# If transition_to_dataset_frame is provided, use it to merge the transitions.
merged = []
if obs_transition is not None: # The observation from the robot
merged.append(obs_transition)
if teleop_transition is not None: # The action from teleop
merged.append(teleop_transition)
if policy_transition is not None: # The action from policy
merged.append(policy_transition)
frame = transition_to_dataset_frame(
merged if len(merged) > 1 else merged[0], dataset.features
) # Convert the observation to the dataset format
dataset.add_frame(frame, task=single_task)
if display_data:
log_rerun_data(observation, action)
log_rerun_data([obs_transition, teleop_transition or policy_transition])
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
@@ -296,7 +406,15 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
dataset_features = {**action_features, **obs_features}
# Add next.* features that are generated during recording
transition_features = {
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
"next.truncated": {"dtype": "bool", "shape": (1,), "names": None},
}
dataset_features = {**action_features, **obs_features, **transition_features}
if cfg.resume:
dataset = LeRobotDataset(
@@ -328,6 +446,18 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor = None
postprocessor = None
if cfg.policy is not None:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
"rename_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
if teleop is not None:
@@ -345,6 +475,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
fps=cfg.dataset.fps,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
@@ -45,9 +45,10 @@ from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
import draccus
from lerobot.configs import parser
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -83,13 +84,25 @@ class ReplayConfig:
dataset: DatasetReplayConfig
# Use vocal synthesis to read events.
play_sounds: bool = True
# Optional processor for actions before sending to robot
robot_action_processor: RobotProcessorPipeline | None = None
@draccus.wrap()
@parser.wrap()
def replay(cfg: ReplayConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
# Initialize robot action processor with default if not provided
robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline(
steps=[IdentityProcessorStep()],
to_transition=action_to_transition,
to_output=transition_to_robot_action, # type: ignore[arg-type]
)
# Reset processor
robot_action_processor.reset()
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
actions = dataset.hf_dataset.select_columns("action")
@@ -104,7 +117,10 @@ def replay(cfg: ReplayConfig):
for i, name in enumerate(dataset.features["action"]["names"]):
action[name] = action_array[i]
robot.send_action(action)
# Process action through robot action processor
processed_action = robot_action_processor(action)
robot.send_action(processed_action) # type: ignore[arg-type]
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / dataset.fps - dt_s)
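The `busy_wait` call pads each iteration to the dataset's frame period: at `fps` frames per second, the loop body plus the wait should total `1 / fps` seconds, and an overrunning body gets no wait at all. The arithmetic in isolation (assumed helper name, not the lerobot implementation):

```python
def remaining_wait(fps: float, elapsed_s: float) -> float:
    # Time left in this control cycle; clamp at zero when the body overran.
    return max(0.0, 1.0 / fps - elapsed_s)


print(remaining_wait(30, 0.01))  # ~0.0233s left in a 33.3ms cycle
print(remaining_wait(30, 0.05))  # 0.0 — iteration overran, no wait
```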
@@ -14,6 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
from .config_so100_follower import SO100FollowerConfig
from .so100_follower import SO100Follower
from .so100_follower_end_effector import SO100FollowerEndEffector
@@ -39,35 +39,3 @@ class SO100FollowerConfig(RobotConfig):
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
@RobotConfig.register_subclass("so100_follower_end_effector")
@dataclass
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
"""Configuration for the SO100FollowerEndEffector robot."""
# Path to URDF file for kinematics
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
urdf_path: str | None = None
# End-effector frame name in the URDF
target_frame_name: str = "gripper_frame_link"
# Default bounds for the end-effector position (in meters)
end_effector_bounds: dict[str, list[float]] = field(
default_factory=lambda: {
"min": [-1.0, -1.0, -1.0], # min x, y, z
"max": [1.0, 1.0, 1.0], # max x, y, z
}
)
max_gripper_pos: float = 50
end_effector_step_sizes: dict[str, float] = field(
default_factory=lambda: {
"x": 0.02,
"y": 0.02,
"z": 0.02,
}
)
@@ -0,0 +1,460 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
import numpy as np
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
ActionProcessorStep,
ComplementaryDataProcessorStep,
EnvTransition,
ObservationProcessorStep,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.robots.robot import Robot
from lerobot.utils.rotation import Rotation
@ProcessorStepRegistry.register("ee_reference_and_delta")
@dataclass
class EEReferenceAndDelta(ActionProcessorStep):
"""
Compute the desired end-effector pose from the target pose and the current pose.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
"""
kinematics: RobotKinematics
end_effector_step_sizes: dict
motor_names: list[str]
use_latched_reference: bool = (
True # If True, latch reference on enable; if False, always use current pose
)
reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
_prev_enabled: bool = field(default=False, init=False, repr=False)
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, action):
new_action = action.copy()
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
# Get joint positions from complementary data
raw = comp.get("raw_joint_positions", None)
if raw is None:
raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
)
if "reference_joint_positions" in comp:
q = comp["reference_joint_positions"]
else:
q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
# Current pose from FK on measured joints
t_curr = self.kinematics.forward_kinematics(q)
enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))
desired = None
if enabled:
ref = t_curr
if self.use_latched_reference:
# Latched reference mode: latch reference at the rising edge
if not self._prev_enabled or self.reference_ee_pose is None:
self.reference_ee_pose = t_curr.copy()
ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr
delta_p = np.array(
[
tx * self.end_effector_step_sizes["x"],
ty * self.end_effector_step_sizes["y"],
tz * self.end_effector_step_sizes["z"],
],
dtype=float,
)
r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
desired = np.eye(4, dtype=float)
desired[:3, :3] = ref[:3, :3] @ r_abs
desired[:3, 3] = ref[:3, 3] + delta_p
self._command_when_disabled = desired.copy()
else:
# While disabled, keep sending the same command to avoid drift.
if self._command_when_disabled is None:
# If we've never had an enabled command yet, freeze current FK pose once.
self._command_when_disabled = t_curr.copy()
desired = self._command_when_disabled.copy()
# Write action fields
pos = desired[:3, 3]
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
new_action[f"{ACTION}.ee.x"] = float(pos[0])
new_action[f"{ACTION}.ee.y"] = float(pos[1])
new_action[f"{ACTION}.ee.z"] = float(pos[2])
new_action[f"{ACTION}.ee.wx"] = float(tw[0])
new_action[f"{ACTION}.ee.wy"] = float(tw[1])
new_action[f"{ACTION}.ee.wz"] = float(tw[2])
self._prev_enabled = enabled
return new_action
def reset(self):
self._prev_enabled = False
self.reference_ee_pose = None
self._command_when_disabled = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.enabled", None)
features.pop(f"{ACTION}.target_x", None)
features.pop(f"{ACTION}.target_y", None)
features.pop(f"{ACTION}.target_z", None)
features.pop(f"{ACTION}.target_wx", None)
features.pop(f"{ACTION}.target_wy", None)
features.pop(f"{ACTION}.target_wz", None)
features[f"{ACTION}.ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
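The delta composition above (scale the per-axis targets, rotate the latched reference frame by the commanded rotation vector, translate by the scaled deltas) can be sketched with plain NumPy. The `rotvec_to_matrix` helper below is a hypothetical Rodrigues-formula stand-in for the in-house `lerobot.utils.rotation.Rotation` class, not its actual implementation:

```python
import numpy as np

def rotvec_to_matrix(rv):
    """Rodrigues formula: rotation vector (axis * angle) -> 3x3 rotation matrix."""
    rv = np.asarray(rv, dtype=float)
    theta = np.linalg.norm(rv)
    if theta < 1e-12:
        return np.eye(3)
    k = rv / theta
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

def compose_desired_pose(ref, target_xyz, target_rotvec, step_sizes):
    """Scale per-axis deltas, then compose them with the latched 4x4 reference pose."""
    delta_p = np.array([t * step_sizes[ax] for t, ax in zip(target_xyz, "xyz")])
    desired = np.eye(4)
    desired[:3, :3] = ref[:3, :3] @ rotvec_to_matrix(target_rotvec)
    desired[:3, 3] = ref[:3, 3] + delta_p
    return desired
```

With the config's default step sizes of 0.02 m, a full-scale `target_x = 1.0` moves the commanded pose 2 cm along x from the latched reference.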
@ProcessorStepRegistry.register("ee_bounds_and_safety")
@dataclass
class EEBoundsAndSafety(ActionProcessorStep):
"""
Clip the end-effector pose to the bounds and check for jumps.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
"""
end_effector_bounds: dict
max_ee_step_m: float = 0.05
max_ee_twist_step_rad: float = 0.20
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, act: dict) -> dict:
x = act.get(f"{ACTION}.ee.x", None)
y = act.get(f"{ACTION}.ee.y", None)
z = act.get(f"{ACTION}.ee.z", None)
wx = act.get(f"{ACTION}.ee.wx", None)
wy = act.get(f"{ACTION}.ee.wy", None)
wz = act.get(f"{ACTION}.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
raise ValueError(
"Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action"
)
pos = np.array([x, y, z], dtype=float)
twist = np.array([wx, wy, wz], dtype=float)
# Clip position
pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"])
# Check for jumps in position
if self._last_pos is not None:
dpos = pos - self._last_pos
n = float(np.linalg.norm(dpos))
if n > self.max_ee_step_m and n > 0:
pos = self._last_pos + dpos * (self.max_ee_step_m / n)
raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")
self._last_pos = pos
self._last_twist = twist
act[f"{ACTION}.ee.x"] = float(pos[0])
act[f"{ACTION}.ee.y"] = float(pos[1])
act[f"{ACTION}.ee.z"] = float(pos[2])
act[f"{ACTION}.ee.wx"] = float(twist[0])
act[f"{ACTION}.ee.wy"] = float(twist[1])
act[f"{ACTION}.ee.wz"] = float(twist[2])
return act
def reset(self):
self._last_pos = None
self._last_twist = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# Features already carry f"{ACTION}.ee.{x,y,z,wx,wy,wz}"; the pose keys pass through unchanged.
return features
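The bounds-and-jump logic can be illustrated with a small NumPy sketch. Note one deliberate difference: where the step above raises on an oversized jump, this sketch rescales the step to the limit instead; all bounds and limits here are illustrative numbers:

```python
import numpy as np

def clip_and_limit_step(pos, last_pos, bounds_min, bounds_max, max_step):
    """Clip a commanded EE position to the workspace, then cap per-step motion."""
    pos = np.clip(np.asarray(pos, dtype=float), bounds_min, bounds_max)
    if last_pos is not None:
        d = pos - last_pos
        n = float(np.linalg.norm(d))
        if n > max_step:
            # Rescale the step to the allowed magnitude instead of raising.
            pos = last_pos + d * (max_step / n)
    return pos
```

Rescaling keeps the robot moving smoothly toward the target, while raising (as the processor step does) surfaces suspicious commands to the caller; which is preferable depends on the deployment.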
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
@dataclass
class InverseKinematicsEEToJoints(ProcessorStep):
"""
Compute the desired joint positions from the desired end-effector pose.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
Output ACTION keys:
{
"action.joint_name_1.pos": float,
"action.joint_name_2.pos": float,
...
"action.joint_name_n.pos": float,
}
"""
kinematics: RobotKinematics
motor_names: list[str]
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
initial_guess_current_joints: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
new_transition = transition.copy()
act = new_transition.get(TransitionKey.ACTION) or {}
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
x = act.get(f"{ACTION}.ee.x", None)
y = act.get(f"{ACTION}.ee.y", None)
z = act.get(f"{ACTION}.ee.z", None)
wx = act.get(f"{ACTION}.ee.wx", None)
wy = act.get(f"{ACTION}.ee.wy", None)
wz = act.get(f"{ACTION}.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
return new_transition
# Get joint positions from complementary data
raw = comp.get("raw_joint_positions", None)
if raw is None:
raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
)
if self.initial_guess_current_joints: # Use current joints as initial guess
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
else: # Use previous ik solution as initial guess
if self.q_curr is None:
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
# Build desired 4x4 transform from pos + rotvec (twist)
t_des = np.eye(4, dtype=float)
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
t_des[:3, 3] = [x, y, z]
# Compute inverse kinematics
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
self.q_curr = q_target
new_act = dict(act)
for i, name in enumerate(self.motor_names):
if name == "gripper":
# TODO(pepijn): Investigate if this is correct
# Do we want an observation key in the action field?
new_act[f"{ACTION}.gripper.pos"] = float(raw["gripper"])
else:
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
new_transition[TransitionKey.ACTION] = new_act
if not self.initial_guess_current_joints:
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
for name in self.motor_names:
features[f"{ACTION}.{name}.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
def reset(self):
self.q_curr = None
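The warm-start choice controlled by `initial_guess_current_joints` (seed IK from the measured joints, or from the previous IK solution) can be demonstrated on a toy planar 2-link arm. Everything here, including the link lengths and the damped-Newton solver, is a hypothetical sketch, not the `RobotKinematics` implementation:

```python
import numpy as np

L1, L2 = 1.0, 1.0  # hypothetical link lengths

def fk(q):
    """Forward kinematics: joint angles -> end-effector (x, y)."""
    return np.array([L1 * np.cos(q[0]) + L2 * np.cos(q[0] + q[1]),
                     L1 * np.sin(q[0]) + L2 * np.sin(q[0] + q[1])])

def jacobian(q):
    s1, c1 = np.sin(q[0]), np.cos(q[0])
    s12, c12 = np.sin(q[0] + q[1]), np.cos(q[0] + q[1])
    return np.array([[-L1 * s1 - L2 * s12, -L2 * s12],
                     [L1 * c1 + L2 * c12, L2 * c12]])

def ik(q_guess, target, iters=300, tol=1e-9):
    """Damped Newton IK via the Jacobian pseudoinverse; q_guess is the warm start."""
    q = np.asarray(q_guess, dtype=float).copy()
    for _ in range(iters):
        err = target - fk(q)
        if np.linalg.norm(err) < tol:
            break
        q += 0.5 * np.linalg.pinv(jacobian(q)) @ err
    return q
```

Reusing the previous solution as the next guess (the `initial_guess_current_joints=False` path) keeps successive targets close to the starting point, which both speeds convergence and avoids jumping between redundant IK branches.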
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
@dataclass
class GripperVelocityToJoint(ProcessorStep):
"""
Convert the gripper velocity to a joint velocity.
Input ACTION keys:
{
"action.gripper": float,
}
Output ACTION keys:
{
"action.gripper.pos": float,
}
"""
motor_names: list[str]
speed_factor: float = 20.0
clip_min: float = 0.0
clip_max: float = 100.0
discrete_gripper: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
act = new_transition.get(TransitionKey.ACTION) or {}
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if f"{ACTION}.gripper" not in act:
raise ValueError(f"Required action key '{ACTION}.gripper' not found in transition")
if "gripper" not in self.motor_names:
raise ValueError(
f"Required motor name 'gripper' not found in self.motor_names={self.motor_names}"
)
if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2]
# 0: close, 1: stay, 2: open
# Shift them to [-1, 0, 1] and scale by clip_max
gripper_action = act.get(f"{ACTION}.gripper", 1.0)
gripper_action = gripper_action - 1.0
gripper_action *= self.clip_max
act[f"{ACTION}.gripper"] = gripper_action
# Get current gripper position from complementary data (default to 0.0 if no reading is available)
raw = comp.get("raw_joint_positions") or {}
curr_pos = float(raw.get("gripper", 0.0))
# Compute desired gripper velocity
u = float(act.get(f"{ACTION}.gripper", 0.0))
delta = u * float(self.speed_factor)
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
new_act = dict(act)
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
new_act.pop(f"{ACTION}.gripper", None)
new_transition[TransitionKey.ACTION] = new_act
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.gripper", None)
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{OBS_STATE}.gripper.pos"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
return features
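The velocity-to-position integration above reduces to a clip of `curr_pos + u * speed_factor`. A minimal sketch, using the step's default constants (the helper names are illustrative):

```python
import numpy as np

CLIP_MIN, CLIP_MAX, SPEED_FACTOR = 0.0, 100.0, 20.0

def discrete_to_velocity(a):
    """Map a discrete action {0: close, 1: stay, 2: open} to a signed command."""
    return (a - 1.0) * CLIP_MAX

def integrate_gripper(curr_pos, u):
    """Integrate a velocity-style command into a clipped position target."""
    return float(np.clip(curr_pos + u * SPEED_FACTOR, CLIP_MIN, CLIP_MAX))
```

Note that a discrete command produces a command of magnitude `CLIP_MAX`, so after multiplying by `SPEED_FACTOR` the result simply saturates at fully open or fully closed in one step; the clipping makes that safe.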
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
@dataclass
class ForwardKinematicsJointsToEE(ObservationProcessorStep):
"""
Compute the end-effector pose from the joint positions.
Input OBSERVATION keys:
{
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
}
Output OBSERVATION keys:
{
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
}
"""
kinematics: RobotKinematics
motor_names: list[str]
def observation(self, obs: dict) -> dict:
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
t = self.kinematics.forward_kinematics(q)
pos = t[:3, 3]
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
obs[f"{OBS_STATE}.ee.x"] = float(pos[0])
obs[f"{OBS_STATE}.ee.y"] = float(pos[1])
obs[f"{OBS_STATE}.ee.z"] = float(pos[2])
obs[f"{OBS_STATE}.ee.wx"] = float(tw[0])
obs[f"{OBS_STATE}.ee.wy"] = float(tw[1])
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
return obs
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset
for k in ["x", "y", "z", "wx", "wy", "wz"]:
features[f"{OBS_STATE}.ee.{k}"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
return features
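Going from the FK rotation matrix back to the `wx, wy, wz` rotation vector is the log map (the inverse of the Rodrigues formula). A hypothetical stand-in for the in-house `Rotation.from_matrix(...).as_rotvec()`, which breaks down near 180° where `sin(theta)` vanishes and a more careful branch would be needed:

```python
import numpy as np

def matrix_to_rotvec(R):
    """Log map: 3x3 rotation matrix -> rotation vector (axis * angle)."""
    cos_theta = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-8:
        return np.zeros(3)
    # Extract the axis from the skew-symmetric part of R.
    axis = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    axis = axis / (2.0 * np.sin(theta))
    return axis * theta
```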
@ProcessorStepRegistry.register("add_robot_observation")
@dataclass
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessorStep):
"""
Read the robot's current observation and insert it into the transition as complementary data.
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
{ "<motor_name>": <float position>, ... }
"""
robot: Robot
def complementary_data(self, comp: dict | None) -> dict:
new_comp = dict(comp)
obs = (
self.robot.get_observation()
) # todo(steven): why not self.transition.get(TransitionKey.OBSERVATION)?
new_comp["raw_joint_positions"] = {
k.removesuffix(".pos"): float(v)
for k, v in obs.items()
if isinstance(k, str) and k.endswith(".pos")
}
return new_comp
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
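The filtering performed above is a plain dict comprehension that keeps only `.pos` keys and strips the suffix; the observation values here are made up for illustration:

```python
# Hypothetical robot observation mixing joint readings and a camera frame
obs = {"shoulder_pan.pos": 12.5, "gripper.pos": 40.0, "camera_front": "<image>"}
raw_joint_positions = {
    k.removesuffix(".pos"): float(v)
    for k, v in obs.items()
    if isinstance(k, str) and k.endswith(".pos")
}
# -> {"shoulder_pan": 12.5, "gripper": 40.0}
```

(`str.removesuffix` requires Python 3.9+.)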
@@ -1,200 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any
import numpy as np
from lerobot.cameras import make_cameras_from_configs
from lerobot.errors import DeviceNotConnectedError
from lerobot.model.kinematics import RobotKinematics
from lerobot.motors import Motor, MotorNormMode
from lerobot.motors.feetech import FeetechMotorsBus
from . import SO100Follower
from .config_so100_follower import SO100FollowerEndEffectorConfig
logger = logging.getLogger(__name__)
class SO100FollowerEndEffector(SO100Follower):
"""
SO100Follower robot with end-effector space control.
This robot inherits from SO100Follower but transforms actions from
end-effector space to joint space before sending them to the motors.
"""
config_class = SO100FollowerEndEffectorConfig
name = "so100_follower_end_effector"
def __init__(self, config: SO100FollowerEndEffectorConfig):
super().__init__(config)
self.bus = FeetechMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES),
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES),
"elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES),
"wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES),
"wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES),
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
self.config = config
# Initialize the kinematics module for the so100 robot
if self.config.urdf_path is None:
raise ValueError(
"urdf_path must be provided in the configuration for end-effector control. "
"Please set urdf_path in your SO100FollowerEndEffectorConfig."
)
self.kinematics = RobotKinematics(
urdf_path=self.config.urdf_path,
target_frame_name=self.config.target_frame_name,
)
# Store the bounds for end-effector position
self.end_effector_bounds = self.config.end_effector_bounds
self.current_ee_pos = None
self.current_joint_pos = None
@property
def action_features(self) -> dict[str, Any]:
"""
Define action features for end-effector control.
Returns dictionary with dtype, shape, and names.
"""
return {
"dtype": "float32",
"shape": (4,),
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
}
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
"""
Transform action from end-effector space to joint space and send to motors.
Args:
action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control
or a numpy array with [delta_x, delta_y, delta_z]
Returns:
The joint-space action that was sent to the motors
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Convert action to numpy array if not already
if isinstance(action, dict):
if all(k in action for k in ["delta_x", "delta_y", "delta_z"]):
delta_ee = np.array(
[
action["delta_x"] * self.config.end_effector_step_sizes["x"],
action["delta_y"] * self.config.end_effector_step_sizes["y"],
action["delta_z"] * self.config.end_effector_step_sizes["z"],
],
dtype=np.float32,
)
if "gripper" not in action:
action["gripper"] = [1.0]
action = np.append(delta_ee, action["gripper"])
else:
logger.warning(
f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
)
action = np.zeros(4, dtype=np.float32)
if self.current_joint_pos is None:
# Read current joint positions
current_joint_pos = self.bus.sync_read("Present_Position")
self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors])
# Calculate current end-effector position using forward kinematics
if self.current_ee_pos is None:
self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
# Set desired end-effector position by adding delta
desired_ee_pos = np.eye(4)
desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
# Add delta to position and clip to bounds
desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
if self.end_effector_bounds is not None:
desired_ee_pos[:3, 3] = np.clip(
desired_ee_pos[:3, 3],
self.end_effector_bounds["min"],
self.end_effector_bounds["max"],
)
# Compute inverse kinematics to get joint positions
target_joint_values_in_degrees = self.kinematics.inverse_kinematics(
self.current_joint_pos, desired_ee_pos
)
# Create joint space action dictionary
joint_action = {
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
}
# Handle gripper separately if included in action
# Gripper delta action is in the range 0 - 2,
# We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos
joint_action["gripper.pos"] = np.clip(
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
5,
self.config.max_gripper_pos,
)
self.current_ee_pos = desired_ee_pos.copy()
self.current_joint_pos = target_joint_values_in_degrees.copy()
self.current_joint_pos[-1] = joint_action["gripper.pos"]
# Send joint space action to parent class
return super().send_action(joint_action)
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Read arm position
start = time.perf_counter()
obs_dict = self.bus.sync_read("Present_Position")
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[cam_key] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def reset(self):
self.current_ee_pos = None
self.current_joint_pos = None
@@ -29,10 +29,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .so100_follower import SO100Follower
return SO100Follower(config)
elif config.type == "so100_follower_end_effector":
from .so100_follower import SO100FollowerEndEffector
return SO100FollowerEndEffector(config)
elif config.type == "so101_follower":
from .so101_follower import SO101Follower
@@ -69,6 +65,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
raise ValueError(config.type)
# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset
def ensure_safe_goal_position(
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float]
) -> dict[str, float]:
@@ -62,9 +62,16 @@ from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.processor import TransitionKey
from lerobot.robots import so100_follower # noqa: F401
from lerobot.scripts.rl.gym_manipulator import make_robot_env
from lerobot.scripts.rl.gym_manipulator import (
create_transition,
make_processors,
make_robot_env,
step_env_and_process_transition,
)
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
@@ -91,7 +98,6 @@ from lerobot.utils.utils import (
ACTOR_SHUTDOWN_TIMEOUT = 30
#################################################
# Main entry point #
#################################################
@@ -236,7 +242,8 @@ def act_with_policy(
logging.info("make_env online")
online_env = make_robot_env(cfg=cfg.env)
online_env, teleop_device = make_robot_env(cfg=cfg.env)
env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -257,6 +264,12 @@ def act_with_policy(
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
@@ -274,45 +287,71 @@ def act_with_policy(
logging.info("[ACTOR] Shutting down act_with_policy")
return
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with policy_timer:
action = policy.select_action(batch=obs)
policy_fps = policy_timer.fps_last
observation = {
k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features
}
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
policy_fps = policy_timer.fps_last
else:
action = online_env.action_space.sample()
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
next_obs, reward, done, truncated, info = online_env.step(action)
# Use the new step function
new_transition = step_env_and_process_transition(
env=online_env,
transition=transition,
action=action,
env_processor=env_processor,
action_processor=action_processor,
)
# Extract values from processed transition
next_observation = {
k: v
for k, v in new_transition[TransitionKey.OBSERVATION].items()
if k in cfg.policy.input_features
}
# Teleop action is the action that was executed in the environment
# It is either the action from the teleop device or the action from the policy
executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
reward = new_transition[TransitionKey.REWARD]
done = new_transition.get(TransitionKey.DONE, False)
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
episode_total_steps += 1
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
# NOTE: The action space for demonstration before hand is with the full action space
# but sometimes for example we want to deactivate the gripper
action = info["action_intervention"]
# Check for intervention from transition info
intervention_info = new_transition[TransitionKey.INFO]
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
episode_intervention = True
# Increment intervention steps counter
episode_intervention_steps += 1
complementary_info = {
"discrete_penalty": torch.tensor(
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
),
}
# Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append(
Transition(
state=obs,
action=action,
state=observation,
action=executed_action,
reward=reward,
next_state=next_obs,
next_state=next_observation,
done=done,
truncated=truncated, # TODO: (azouitine) Handle truncation properly
complementary_info=info,
truncated=truncated,
complementary_info=complementary_info,
)
)
# assign obs to the next obs and continue the rollout
obs = next_obs
# Update transition for next iteration
transition = new_transition
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
@@ -347,12 +386,20 @@ def act_with_policy(
)
)
# Reset intervention counters
# Reset intervention counters and environment
sum_reward_episode = 0.0
episode_intervention = False
episode_intervention_steps = 0
episode_total_steps = 0
# Reset environment and processors
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
@@ -75,6 +75,7 @@ from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.robots import so100_follower # noqa: F401
from lerobot.scripts.rl import learner_service
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2_grpc
from lerobot.transport.utils import (
MAX_MESSAGE_SIZE,
@@ -1048,10 +1049,8 @@ def get_observation_features(
return None, None
with torch.no_grad():
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
next_observation_features = policy.actor.encoder.get_cached_image_features(
next_observations, normalize=True
)
observation_features = policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
@@ -1176,7 +1175,7 @@ def process_transitions(
# Add to offline buffer if it's an intervention
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
"is_intervention"
TeleopEvents.IS_INTERVENTION
):
offline_replay_buffer.add(**transition)
@@ -26,12 +26,13 @@ from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.scripts.eval import eval_policy
@@ -140,6 +141,9 @@ def train(cfg: TrainPipelineConfig):
cfg=cfg.policy,
ds_meta=dataset.meta,
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -149,6 +153,12 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
preprocessor.from_pretrained(
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
)
postprocessor.from_pretrained(
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
@@ -203,12 +213,9 @@ def train(cfg: TrainPipelineConfig):
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
train_tracker, output_dict = update_policy(
train_tracker,
policy,
@@ -240,7 +247,9 @@ def train(cfg: TrainPipelineConfig):
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
save_checkpoint(
checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
@@ -284,6 +293,8 @@ def train(cfg: TrainPipelineConfig):
if cfg.policy.push_to_hub:
policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id)
postprocessor.push_to_hub(cfg.policy.repo_id)
def main():
@@ -56,11 +56,17 @@ import time
from dataclasses import asdict, dataclass
from pprint import pformat
import draccus
import rerun as rr
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
from lerobot.processor.converters import (
action_to_transition,
observation_to_transition,
transition_to_robot_action,
)
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -97,39 +103,84 @@ class TeleoperateConfig:
teleop_time_s: float | None = None
# Display all cameras on screen
display_data: bool = False
# Optional processors for data transformation
teleop_action_processor: RobotProcessorPipeline | None = None # runs after teleop
robot_action_processor: RobotProcessorPipeline | None = None # runs before robot
robot_observation_processor: RobotProcessorPipeline | None = None # runs after robot
def teleop_loop(
teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None
teleop: Teleoperator,
robot: Robot,
fps: int,
display_data: bool = False,
duration: float | None = None,
teleop_action_processor: RobotProcessorPipeline | None = None,
robot_action_processor: RobotProcessorPipeline | None = None,
robot_observation_processor: RobotProcessorPipeline | None = None,
):
# Initialize processors with defaults if not provided
teleop_action_processor = teleop_action_processor or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=lambda tr: tr
)
robot_action_processor = robot_action_processor or RobotProcessorPipeline(
steps=[IdentityProcessorStep()],
to_transition=lambda tr: tr,
to_output=transition_to_robot_action, # type: ignore[arg-type]
)
robot_observation_processor = robot_observation_processor or RobotProcessorPipeline(
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=lambda tr: tr,
)
# Reset processors
teleop_action_processor.reset()
robot_action_processor.reset()
robot_observation_processor.reset()
display_len = max(len(key) for key in robot.action_features)
start = time.perf_counter()
while True:
loop_start = time.perf_counter()
action = teleop.get_action()
if display_data:
observation = robot.get_observation()
log_rerun_data(observation, action)
robot.send_action(action)
# Get teleop action
raw_action = teleop.get_action()
# Process teleop action through pipeline
teleop_transition = teleop_action_processor(raw_action)
# Process action for robot through pipeline
robot_action_to_send = robot_action_processor(teleop_transition)
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
robot.send_action(robot_action_to_send) # type: ignore[arg-type]
if display_data:
# Get robot observation
obs = robot.get_observation()
# Process robot observation through pipeline
obs_transition = robot_observation_processor(obs)
log_rerun_data(observation=obs_transition, action=teleop_transition)

print("\n" + "-" * (display_len + 10))
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
# Display the final robot action that was sent
for motor, value in robot_action_to_send.items():
print(f"{motor:<{display_len}} | {value:>7.2f}")
move_cursor_up(len(robot_action_to_send) + 5)
dt_s = time.perf_counter() - loop_start
busy_wait(1 / fps - dt_s)
loop_s = time.perf_counter() - loop_start
print("\n" + "-" * (display_len + 10))
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
for motor, value in action.items():
print(f"{motor:<{display_len}} | {value:>7.2f}")
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
if duration is not None and time.perf_counter() - start >= duration:
return
move_cursor_up(len(action) + 5)
@draccus.wrap()
@parser.wrap()
def teleoperate(cfg: TeleoperateConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
@@ -143,7 +194,16 @@ def teleoperate(cfg: TeleoperateConfig):
robot.connect()
try:
teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s)
teleop_loop(
teleop=teleop,
robot=robot,
fps=cfg.fps,
display_data=cfg.display_data,
duration=cfg.teleop_time_s,
teleop_action_processor=cfg.teleop_action_processor,
robot_action_processor=cfg.robot_action_processor,
robot_observation_processor=cfg.robot_observation_processor,
)
except KeyboardInterrupt:
pass
finally:
@@ -16,4 +16,4 @@
from .config import TeleoperatorConfig
from .teleoperator import Teleoperator
from .utils import make_teleoperator_from_config
from .utils import TeleopEvents, make_teleoperator_from_config
@@ -16,6 +16,8 @@
import logging
from ..utils import TeleopEvents
class InputController:
"""Base class for input controllers that generate motion deltas."""
@@ -134,10 +136,10 @@ class KeyboardController(InputController):
return False
elif key == keyboard.Key.enter:
self.key_states["success"] = True
self.episode_end_status = "success"
self.episode_end_status = TeleopEvents.SUCCESS
elif key == keyboard.Key.backspace:
self.key_states["failure"] = True
self.episode_end_status = "failure"
self.episode_end_status = TeleopEvents.FAILURE
except AttributeError:
pass
@@ -255,13 +257,13 @@ class GamepadController(InputController):
for event in pygame.event.get():
if event.type == pygame.JOYBUTTONDOWN:
if event.button == 3:
self.episode_end_status = "success"
self.episode_end_status = TeleopEvents.SUCCESS
# A button (1) for failure
elif event.button == 1:
self.episode_end_status = "failure"
self.episode_end_status = TeleopEvents.FAILURE
# X button (0) for rerecord
elif event.button == 0:
self.episode_end_status = "rerecord_episode"
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
# RB button (6) for closing gripper
elif event.button == 6:
@@ -451,11 +453,11 @@ class GamepadControllerHID(InputController):
# Check if X/Square button (bit 5) is pressed for failure
# Check if A/Cross button (bit 4) is pressed for rerecording
if buttons & 1 << 7:
self.episode_end_status = "success"
self.episode_end_status = TeleopEvents.SUCCESS
elif buttons & 1 << 5:
self.episode_end_status = "failure"
self.episode_end_status = TeleopEvents.FAILURE
elif buttons & 1 << 4:
self.episode_end_status = "rerecord_episode"
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
else:
self.episode_end_status = None
@@ -21,6 +21,7 @@ from typing import Any
import numpy as np
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
from .configuration_gamepad import GamepadTeleopConfig
@@ -93,9 +94,9 @@ class GamepadTeleop(Teleoperator):
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
action_dict = {
"delta_x": gamepad_action[0],
"delta_y": gamepad_action[1],
"delta_z": gamepad_action[2],
"action.delta_x": gamepad_action[0],
"action.delta_y": gamepad_action[1],
"action.delta_z": gamepad_action[2],
}
# Default gripper action is to stay
@@ -107,6 +108,48 @@ class GamepadTeleop(Teleoperator):
return action_dict
def get_teleop_events(self) -> dict[str, Any]:
"""
Get extra control events from the gamepad such as intervention status,
episode termination, success indicators, etc.
Returns:
Dictionary containing:
- is_intervention: bool - Whether human is currently intervening
- terminate_episode: bool - Whether to terminate the current episode
- success: bool - Whether the episode was successful
- rerecord_episode: bool - Whether to rerecord the episode
"""
if self.gamepad is None:
return {
TeleopEvents.IS_INTERVENTION: False,
TeleopEvents.TERMINATE_EPISODE: False,
TeleopEvents.SUCCESS: False,
TeleopEvents.RERECORD_EPISODE: False,
}
# Update gamepad state to get fresh inputs
self.gamepad.update()
# Check if intervention is active
is_intervention = self.gamepad.should_intervene()
# Get episode end status
episode_end_status = self.gamepad.get_episode_end_status()
terminate_episode = episode_end_status in [
TeleopEvents.RERECORD_EPISODE,
TeleopEvents.FAILURE,
]
success = episode_end_status == TeleopEvents.SUCCESS
rerecord_episode = episode_end_status == TeleopEvents.RERECORD_EPISODE
return {
TeleopEvents.IS_INTERVENTION: is_intervention,
TeleopEvents.TERMINATE_EPISODE: terminate_episode,
TeleopEvents.SUCCESS: success,
TeleopEvents.RERECORD_EPISODE: rerecord_episode,
}
def disconnect(self) -> None:
"""Disconnect from the gamepad."""
if self.gamepad is not None:
@@ -24,6 +24,7 @@ from typing import Any
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig
PYNPUT_AVAILABLE = True
@@ -167,25 +168,15 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
return {
"dtype": "float32",
"shape": (4,),
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2, "action.gripper": 3},
}
else:
return {
"dtype": "float32",
"shape": (3,),
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2},
}
def _on_press(self, key):
if hasattr(key, "char"):
key = key.char
self.event_queue.put((key, True))
def _on_release(self, key):
if hasattr(key, "char"):
key = key.char
self.event_queue.put((key, False))
def get_action(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(
@@ -226,12 +217,75 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
self.current_pressed.clear()
action_dict = {
"delta_x": delta_x,
"delta_y": delta_y,
"delta_z": delta_z,
"action.delta_x": delta_x,
"action.delta_y": delta_y,
"action.delta_z": delta_z,
}
if self.config.use_gripper:
action_dict["gripper"] = gripper_action
return action_dict
def get_teleop_events(self) -> dict[str, Any]:
"""
Get extra control events from the keyboard such as intervention status,
episode termination, success indicators, etc.
Keyboard mappings:
- Any movement keys pressed = intervention active
- 's' key = success (terminate episode successfully)
- 'r' key = rerecord episode (terminate and rerecord)
- 'q' key = quit episode (terminate without success)
Returns:
Dictionary containing:
- is_intervention: bool - Whether human is currently intervening
- terminate_episode: bool - Whether to terminate the current episode
- success: bool - Whether the episode was successful
- rerecord_episode: bool - Whether to rerecord the episode
"""
if not self.is_connected:
return {
TeleopEvents.IS_INTERVENTION: False,
TeleopEvents.TERMINATE_EPISODE: False,
TeleopEvents.SUCCESS: False,
TeleopEvents.RERECORD_EPISODE: False,
}
# Check if any movement keys are currently pressed (indicates intervention)
movement_keys = [
keyboard.Key.up,
keyboard.Key.down,
keyboard.Key.left,
keyboard.Key.right,
keyboard.Key.shift,
keyboard.Key.shift_r,
keyboard.Key.ctrl_r,
keyboard.Key.ctrl_l,
]
is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys)
# Check for episode control commands from misc_keys_queue
terminate_episode = False
success = False
rerecord_episode = False
# Process any pending misc keys
while not self.misc_keys_queue.empty():
key = self.misc_keys_queue.get_nowait()
if key == "s":
success = True
elif key == "r":
terminate_episode = True
rerecord_episode = True
elif key == "q":
terminate_episode = True
success = False
return {
TeleopEvents.IS_INTERVENTION: is_intervention,
TeleopEvents.TERMINATE_EPISODE: terminate_episode,
TeleopEvents.SUCCESS: success,
TeleopEvents.RERECORD_EPISODE: rerecord_episode,
}
@@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_phone import PhoneConfig
from .phone import Phone
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
import numpy as np
from ..config import TeleoperatorConfig
class PhoneOS(Enum):
ANDROID = "android"
IOS = "ios"
@TeleoperatorConfig.register_subclass("phone")
@dataclass
class PhoneConfig(TeleoperatorConfig):
phone_os: PhoneOS = PhoneOS.IOS
camera_offset = np.array(
[0.0, -0.02, 0.04]
) # iPhone 14 Pro camera is 2cm off center and 4cm above center
@@ -0,0 +1,97 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION
from lerobot.processor import ActionProcessorStep, ProcessorStepRegistry
from lerobot.teleoperators.phone.config_phone import PhoneOS
@ProcessorStepRegistry.register("map_phone_action_to_robot_action")
@dataclass
class MapPhoneActionToRobotAction(ActionProcessorStep):
"""
Maps the calibrated phone pose (action keys) to robot action inputs.
Expected input ACTION keys:
{
"action.phone.enabled": bool,
"action.phone.pos": np.ndarray,
"action.phone.rot": Rotation,
"action.phone.raw_inputs": dict,
}
Output ACTION keys:
{
"action.enabled": bool,
"action.ee.{x,y,z,wx,wy,wz}" : float
"action.gripper": float,
}
"""
platform: PhoneOS
_enabled_prev: bool = field(default=False, init=False, repr=False)
def action(self, act: dict) -> dict:
# Pop them from the action
enabled = bool(act.pop(f"{ACTION}.phone.enabled", 0))
pos = act.pop(f"{ACTION}.phone.pos", None)
rot = act.pop(f"{ACTION}.phone.rot", None)
inputs = act.pop(f"{ACTION}.phone.raw_inputs", {})
if pos is None or rot is None:
raise ValueError("pos and rot must be present in action")
rotvec = rot.as_rotvec() # Absolute orientation as rotvec
# Map certain inputs to certain actions
if self.platform == PhoneOS.IOS:
gripper = float(inputs.get("a3", 0.0))
else:
a = float(inputs.get("reservedButtonA", 0.0))
b = float(inputs.get("reservedButtonB", 0.0))
gripper = (
a - b
) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
# For some actions we need to invert the axis
act[f"{ACTION}.enabled"] = enabled
act[f"{ACTION}.target_x"] = -pos[1] if enabled else 0.0
act[f"{ACTION}.target_y"] = pos[0] if enabled else 0.0
act[f"{ACTION}.target_z"] = pos[2] if enabled else 0.0
act[f"{ACTION}.target_wx"] = rotvec[1] if enabled else 0.0
act[f"{ACTION}.target_wy"] = rotvec[0] if enabled else 0.0
act[f"{ACTION}.target_wz"] = -rotvec[2] if enabled else 0.0
act[f"{ACTION}.gripper"] = gripper # Still send gripper action when disabled
return act
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.phone.enabled", None)
features.pop(f"{ACTION}.phone.pos", None)
features.pop(f"{ACTION}.phone.rot", None)
features.pop(f"{ACTION}.phone.raw_inputs", None)
features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
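The axis remapping above can be exercised in isolation. This sketch re-implements only the position/rotation sign flips with plain dicts (the `map_phone_to_robot` helper is hypothetical, and it takes a precomputed rotation vector instead of a `Rotation` object; the real step also handles gripper input and feature metadata):

```python
import numpy as np

# Simplified version of MapPhoneActionToRobotAction.action(): pop the
# phone-frame keys, emit robot-frame targets with the same axis swaps
# and sign inversions, zeroing everything when teleop is disabled.
def map_phone_to_robot(act: dict) -> dict:
    enabled = bool(act.pop("action.phone.enabled", 0))
    pos = act.pop("action.phone.pos")
    rotvec = act.pop("action.phone.rotvec")  # assumed precomputed rotvec
    act["action.enabled"] = enabled
    act["action.target_x"] = -pos[1] if enabled else 0.0
    act["action.target_y"] = pos[0] if enabled else 0.0
    act["action.target_z"] = pos[2] if enabled else 0.0
    act["action.target_wx"] = rotvec[1] if enabled else 0.0
    act["action.target_wy"] = rotvec[0] if enabled else 0.0
    act["action.target_wz"] = -rotvec[2] if enabled else 0.0
    return act

out = map_phone_to_robot({
    "action.phone.enabled": True,
    "action.phone.pos": np.array([0.1, 0.2, 0.3]),
    "action.phone.rotvec": np.array([0.0, 0.0, 0.5]),
})
```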
@@ -0,0 +1,358 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Docs:
# hebi: https://docs.hebi.us/tools.html#mobile-io
# teleop: https://github.com/SpesRobotics/teleop
import logging
import threading
import time
import hebi
import numpy as np
from teleop import Teleop
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.utils.rotation import Rotation
logger = logging.getLogger(__name__)
class BasePhone:
_enabled: bool = False
_calib_pos: np.ndarray | None = None
_calib_rot_inv: Rotation | None = None
def _reapply_position_calibration(self, pos: np.ndarray) -> None:
self._calib_pos = pos.copy()
@property
def is_calibrated(self) -> bool:
return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
@property
def action_features(self) -> dict[str, type]:
return {
"phone.pos": np.ndarray, # shape (3,)
"phone.rot": Rotation, # scipy.spatial.transform.Rotation
"phone.raw_inputs": dict, # analogs/buttons or webXR meta
"phone.enabled": bool,
}
@property
def feedback_features(self) -> dict[str, type]:
# No haptic or other feedback implemented yet
return {}
def configure(self) -> None:
# No additional configuration required for phone teleop
pass
def send_feedback(self, feedback: dict[str, float]) -> None:
# We could add haptic feedback (vibrations) here, but it's not implemented yet
raise NotImplementedError
class IOSPhone(BasePhone, Teleoperator):
name = "ios_phone"
def __init__(self, config: PhoneConfig):
super().__init__(config)
self.config = config
self._group = None
@property
def is_connected(self) -> bool:
return self._group is not None
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
lookup = hebi.Lookup()
time.sleep(2.0)
group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
if group is None:
raise RuntimeError("Mobile I/O not found — check name/family settings in the app.")
self._group = group
logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
self.calibrate()
def calibrate(self) -> None:
print(
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
)
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
position, rotation = self._wait_for_capture_trigger()
self._calib_pos = position.copy()
self._calib_rot_inv = rotation.inv()
self._enabled = False
print("Calibration done\n")
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
while True:
has_pose, position, rotation, fb_pose = self._read_current_pose()
if not has_pose:
time.sleep(0.01)
continue
io = getattr(fb_pose, "io", None)
button_b = getattr(io, "b", None) if io is not None else None
button_b1_pressed = False
if button_b is not None:
button_b1_pressed = bool(button_b.get_int(1))
if button_b1_pressed:
return position, rotation
time.sleep(0.01)
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
fbk = self._group.get_next_feedback()
pose = fbk[0]
ar_pos = getattr(pose, "ar_position", None)
ar_quat = getattr(pose, "ar_orientation", None)
if ar_pos is None or ar_quat is None:
return False, None, None, None
# HEBI provides orientation in w, x, y, z format.
# Our Rotation class expects x, y, z, w.
quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw
rot = Rotation.from_quat(quat_xyzw)
pos = ar_pos - rot.apply(self.config.camera_offset)
return True, pos, rot, pose
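The scalar-first to scalar-last conversion above is a one-line slice-and-concatenate. A quick standalone check with the identity quaternion (values here are illustrative):

```python
import numpy as np

# HEBI feedback delivers quaternions scalar-first (w, x, y, z); the
# Rotation class expects scalar-last (x, y, z, w). Concatenating the
# vector part with the scalar moves w to the end, as in
# _read_current_pose above.
ar_quat_wxyz = np.array([1.0, 0.0, 0.0, 0.0])  # identity, scalar-first
quat_xyzw = np.concatenate((ar_quat_wxyz[1:], [ar_quat_wxyz[0]]))
```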
def get_action(self) -> dict:
has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
if not has_pose or not self.is_calibrated:
return {}
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
raw_inputs: dict[str, float | int | bool] = {}
io = getattr(fb_pose, "io", None)
if io is not None:
bank_a, bank_b = io.a, io.b
if bank_a:
for ch in range(1, 9):
if bank_a.has_float(ch):
raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch))
if bank_b:
for ch in range(1, 9):
if bank_b.has_int(ch):
raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch))
elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch))
enable = bool(raw_inputs.get("b1", 0))
# Rising edge then re-capture calibration immediately from current raw pose
if enable and not self._enabled:
self._reapply_position_calibration(raw_position)
# Apply calibration
pos_cal = self._calib_rot_inv.apply(raw_position - self._calib_pos)
rot_cal = self._calib_rot_inv * raw_rotation
self._enabled = enable
return {
"phone.pos": pos_cal,
"phone.rot": rot_cal,
"phone.raw_inputs": raw_inputs,
"phone.enabled": self._enabled,
}
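The enable/recalibration logic in `get_action()` is an edge-triggered state machine: on the rising edge of the enable input, the current raw position becomes the new reference, so subsequent positions are reported relative to it. A minimal sketch under that reading (the `EnableTracker` class is hypothetical and ignores rotation):

```python
import numpy as np

# Rising-edge recalibration: the first frame where enable goes
# False -> True re-anchors the position reference, mirroring
# _reapply_position_calibration above.
class EnableTracker:
    def __init__(self):
        self.enabled = False
        self.calib_pos = np.zeros(3)

    def update(self, enable: bool, raw_pos: np.ndarray) -> np.ndarray:
        if enable and not self.enabled:  # rising edge: re-anchor here
            self.calib_pos = raw_pos.copy()
        self.enabled = enable
        return raw_pos - self.calib_pos  # calibrated position

tracker = EnableTracker()
first = tracker.update(True, np.array([1.0, 2.0, 3.0]))   # edge: zeroed
second = tracker.update(True, np.array([1.5, 2.0, 3.0]))  # delta from anchor
```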
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._group = None
class AndroidPhone(BasePhone, Teleoperator):
name = "android_phone"
def __init__(self, config: PhoneConfig):
super().__init__(config)
self.config = config
self._teleop = None
self._teleop_thread = None
self._latest_pose = None
self._latest_message = None
self._android_lock = threading.Lock()
@property
def is_connected(self) -> bool:
return self._teleop is not None
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
logger.info("Starting teleop stream for Android...")
self._teleop = Teleop()
self._teleop.subscribe(self._android_callback)
self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
self._teleop_thread.start()
logger.info(f"{self} connected, teleop stream started.")
self.calibrate()
def calibrate(self) -> None:
print(
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
)
print("Touch and move on the WebXR page to capture this pose...\n")
pos, rot = self._wait_for_capture_trigger()
self._calib_pos = pos.copy()
self._calib_rot_inv = rot.inv()
self._enabled = False
print("Calibration done\n")
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
while True:
with self._android_lock:
msg = self._latest_message or {}
if bool(msg.get("move", False)):
ok, pos, rot, _pose = self._read_current_pose()
if ok:
return pos, rot
time.sleep(0.01)
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
with self._android_lock:
if self._latest_pose is None:
return False, None, None, None
p = self._latest_pose.copy()
pose = self._latest_pose
rot = Rotation.from_matrix(p[:3, :3])
pos = p[:3, 3] - rot.apply(self.config.camera_offset)
return True, pos, rot, pose
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
with self._android_lock:
self._latest_pose = pose
self._latest_message = message
def get_action(self) -> dict:
ok, raw_pos, raw_rot, pose = self._read_current_pose()
if not ok or not self.is_calibrated:
return {}
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
raw_inputs: dict[str, float | int | bool] = {}
msg = self._latest_message or {}
raw_inputs["move"] = bool(msg.get("move", False))
raw_inputs["scale"] = float(msg.get("scale", 1.0))
raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False))
raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False))
enable = bool(raw_inputs.get("move", False))
# Rising edge then re-capture calibration immediately from current raw pose
if enable and not self._enabled:
self._reapply_position_calibration(raw_pos)
# Apply calibration
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
rot_cal = self._calib_rot_inv * raw_rot
self._enabled = enable
return {
"phone.pos": pos_cal,
"phone.rot": rot_cal,
"phone.raw_inputs": raw_inputs,
"phone.enabled": self._enabled,
}
def disconnect(self) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self._teleop = None
if self._teleop_thread and self._teleop_thread.is_alive():
self._teleop_thread.join(timeout=1.0)
self._teleop_thread = None
self._latest_pose = None
class Phone(Teleoperator):
"""
Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API).
For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.
Press and hold **B1** to enable teleoperation. The first B1 press captures a
reference pose and rotation; releasing and pressing again re-captures the position reference.
"""
config_class = PhoneConfig
name = "phone"
def __init__(self, config: PhoneConfig):
super().__init__(config)
self.config = config
self._phone_impl: Teleoperator
if self.config.phone_os == PhoneOS.IOS:
self._phone_impl = IOSPhone(config)
elif self.config.phone_os == PhoneOS.ANDROID:
self._phone_impl = AndroidPhone(config)
else:
raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")
@property
def is_connected(self) -> bool:
return self._phone_impl.is_connected
def connect(self) -> None:
return self._phone_impl.connect()
def calibrate(self) -> None:
return self._phone_impl.calibrate()
@property
def is_calibrated(self) -> bool:
return self._phone_impl.is_calibrated
@property
def action_features(self) -> dict[str, type]:
return self._phone_impl.action_features
@property
def feedback_features(self) -> dict[str, type]:
return self._phone_impl.feedback_features
def configure(self) -> None:
return self._phone_impl.configure()
def get_action(self) -> dict:
return self._phone_impl.get_action()
def send_feedback(self, feedback: dict[str, float]) -> None:
return self._phone_impl.send_feedback(feedback)
def disconnect(self) -> None:
return self._phone_impl.disconnect()
@@ -12,10 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from .config import TeleoperatorConfig
from .teleoperator import Teleoperator
class TeleopEvents(Enum):
"""Shared constants for teleoperator events across teleoperators."""
SUCCESS = "success"
FAILURE = "failure"
RERECORD_EPISODE = "rerecord_episode"
IS_INTERVENTION = "is_intervention"
TERMINATE_EPISODE = "terminate_episode"
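The enum members above serve double duty in the teleoperators: as episode-end status values and as dictionary keys in `get_teleop_events()`. A self-contained sketch of that usage (the enum is copied from the hunk above; the `events` dict mirrors the shape returned by the gamepad/keyboard teleops):

```python
from enum import Enum

# TeleopEvents as defined above: members are both status values and keys.
class TeleopEvents(Enum):
    SUCCESS = "success"
    FAILURE = "failure"
    RERECORD_EPISODE = "rerecord_episode"
    IS_INTERVENTION = "is_intervention"
    TERMINATE_EPISODE = "terminate_episode"

episode_end_status = TeleopEvents.RERECORD_EPISODE
events = {
    # Rerecord and failure both terminate the episode.
    TeleopEvents.TERMINATE_EPISODE: episode_end_status
    in (TeleopEvents.RERECORD_EPISODE, TeleopEvents.FAILURE),
    TeleopEvents.SUCCESS: episode_end_status == TeleopEvents.SUCCESS,
    TeleopEvents.RERECORD_EPISODE: episode_end_status
    == TeleopEvents.RERECORD_EPISODE,
}
```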
def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
if config.type == "keyboard":
from .keyboard import KeyboardTeleop
@@ -31,6 +31,7 @@ from termcolor import colored
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyProcessorPipeline, TransitionKey
from lerobot.robots import Robot
@@ -101,6 +102,8 @@ def predict_action(
observation: dict[str, np.ndarray],
policy: PreTrainedPolicy,
device: torch.device,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
use_amp: bool,
task: str | None = None,
robot_type: str | None = None,
@@ -122,10 +125,14 @@ def predict_action(
observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else ""
observation = preprocessor(observation)
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
# Remove batch dimension
action = action.squeeze(0)
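The policy returns actions with a leading batch dimension of size one; `squeeze(0)` strips it before the action is sent onward, as in `predict_action` above. A tiny illustration with NumPy in place of a torch tensor (shapes are illustrative):

```python
import numpy as np

# One action for six motors, still carrying its batch dimension.
batched = np.zeros((1, 6), dtype=np.float32)
action = batched.squeeze(0)  # drop the leading batch axis
```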
@@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
_torch_available, _torch_version = is_package_available("torch", return_version=True)
_transformers_available = is_package_available("transformers")
_gym_xarm_available = is_package_available("gym_xarm")
_gym_aloha_available = is_package_available("gym_aloha")
_gym_pusht_available = is_package_available("gym_pusht")
@@ -0,0 +1,174 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom rotation utilities to replace scipy.spatial.transform.Rotation."""
import numpy as np
class Rotation:
"""
Custom rotation class that provides a subset of scipy.spatial.transform.Rotation functionality.
Supports conversions between rotation vectors, rotation matrices, and quaternions.
"""
def __init__(self, quat: np.ndarray) -> None:
"""Initialize rotation from quaternion [x, y, z, w]."""
self._quat = np.asarray(quat, dtype=float)
# Normalize quaternion
norm = np.linalg.norm(self._quat)
if norm > 0:
self._quat = self._quat / norm
@classmethod
def from_rotvec(cls, rotvec: np.ndarray) -> "Rotation":
"""
Create rotation from rotation vector using Rodrigues' formula.
Args:
rotvec: Rotation vector [x, y, z] where magnitude is angle in radians
Returns:
Rotation instance
"""
rotvec = np.asarray(rotvec, dtype=float)
angle = np.linalg.norm(rotvec)
if angle < 1e-8:
# For very small angles, use identity quaternion
quat = np.array([0.0, 0.0, 0.0, 1.0])
else:
axis = rotvec / angle
half_angle = angle / 2.0
sin_half = np.sin(half_angle)
cos_half = np.cos(half_angle)
# Quaternion [x, y, z, w]
quat = np.array([axis[0] * sin_half, axis[1] * sin_half, axis[2] * sin_half, cos_half])
return cls(quat)
@classmethod
def from_matrix(cls, matrix: np.ndarray) -> "Rotation":
"""
Create rotation from 3x3 rotation matrix.
Args:
matrix: 3x3 rotation matrix
Returns:
Rotation instance
"""
matrix = np.asarray(matrix, dtype=float)
# Shepperd's method for converting a rotation matrix to a quaternion
trace = np.trace(matrix)
if trace > 0:
s = np.sqrt(trace + 1.0) * 2 # s = 4 * qw
qw = 0.25 * s
qx = (matrix[2, 1] - matrix[1, 2]) / s
qy = (matrix[0, 2] - matrix[2, 0]) / s
qz = (matrix[1, 0] - matrix[0, 1]) / s
elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]:
s = np.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 # s = 4 * qx
qw = (matrix[2, 1] - matrix[1, 2]) / s
qx = 0.25 * s
qy = (matrix[0, 1] + matrix[1, 0]) / s
qz = (matrix[0, 2] + matrix[2, 0]) / s
elif matrix[1, 1] > matrix[2, 2]:
s = np.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 # s = 4 * qy
qw = (matrix[0, 2] - matrix[2, 0]) / s
qx = (matrix[0, 1] + matrix[1, 0]) / s
qy = 0.25 * s
qz = (matrix[1, 2] + matrix[2, 1]) / s
else:
s = np.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 # s = 4 * qz
qw = (matrix[1, 0] - matrix[0, 1]) / s
qx = (matrix[0, 2] + matrix[2, 0]) / s
qy = (matrix[1, 2] + matrix[2, 1]) / s
qz = 0.25 * s
quat = np.array([qx, qy, qz, qw])
return cls(quat)
@classmethod
def from_quat(cls, quat: np.ndarray) -> "Rotation":
"""
Create rotation from quaternion.
Args:
quat: Quaternion in [x, y, z, w] (scalar-last) format
Returns:
Rotation instance
"""
return cls(quat)
def as_matrix(self) -> np.ndarray:
"""
Convert rotation to 3x3 rotation matrix.
Returns:
3x3 rotation matrix
"""
qx, qy, qz, qw = self._quat
# Compute rotation matrix from quaternion
return np.array(
[
[1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)],
[2 * (qx * qy + qz * qw), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qx * qw)],
[2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx * qx + qy * qy)],
],
dtype=float,
)
def as_rotvec(self) -> np.ndarray:
"""
Convert rotation to rotation vector.
Returns:
Rotation vector [x, y, z] where magnitude is angle in radians
"""
qx, qy, qz, qw = self._quat
# Ensure qw is positive for unique representation
if qw < 0:
qx, qy, qz, qw = -qx, -qy, -qz, -qw
# Compute angle and axis
angle = 2.0 * np.arccos(np.clip(abs(qw), 0.0, 1.0))
sin_half_angle = np.sqrt(1.0 - qw * qw)
if sin_half_angle < 1e-8:
# For very small angles, use linearization: rotvec ≈ 2 * [qx, qy, qz]
return 2.0 * np.array([qx, qy, qz])
# Extract axis and scale by angle
axis = np.array([qx, qy, qz]) / sin_half_angle
return angle * axis
def as_quat(self) -> np.ndarray:
"""
Get quaternion representation.
Returns:
Quaternion [x, y, z, w]
"""
return self._quat.copy()
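The `from_rotvec`/`as_rotvec` pair above should round-trip exactly for ordinary angles. This standalone check re-implements the two conversions as free functions mirroring the class methods, then verifies a 90° rotation about z survives the trip:

```python
import numpy as np

# rotvec -> quaternion via Rodrigues' formula, mirroring from_rotvec.
def rotvec_to_quat(rotvec: np.ndarray) -> np.ndarray:
    angle = np.linalg.norm(rotvec)
    if angle < 1e-8:
        return np.array([0.0, 0.0, 0.0, 1.0])  # identity quaternion
    axis = rotvec / angle
    s, c = np.sin(angle / 2.0), np.cos(angle / 2.0)
    return np.array([axis[0] * s, axis[1] * s, axis[2] * s, c])

# quaternion -> rotvec, mirroring as_rotvec (canonical-sign handling
# and small-angle linearization included).
def quat_to_rotvec(quat: np.ndarray) -> np.ndarray:
    qx, qy, qz, qw = quat
    if qw < 0:  # flip sign for a unique representation
        qx, qy, qz, qw = -qx, -qy, -qz, -qw
    angle = 2.0 * np.arccos(np.clip(qw, 0.0, 1.0))
    sin_half = np.sqrt(max(1.0 - qw * qw, 0.0))
    if sin_half < 1e-8:
        return 2.0 * np.array([qx, qy, qz])  # small-angle linearization
    return angle * np.array([qx, qy, qz]) / sin_half

rv = np.array([0.0, 0.0, np.pi / 2])  # 90 degrees about z
roundtrip = quat_to_rotvec(rotvec_to_quat(rv))
```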
@@ -32,6 +32,7 @@ from lerobot.datasets.utils import load_json, write_json
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyProcessorPipeline
from lerobot.utils.random_utils import load_rng_state, save_rng_state
@@ -74,6 +75,8 @@ def save_checkpoint(
policy: PreTrainedPolicy,
optimizer: Optimizer,
scheduler: LRScheduler | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
) -> None:
"""This function creates the following directory structure:
@@ -81,7 +84,9 @@ def save_checkpoint(
pretrained_model/
config.json # policy config
model.safetensors # policy weights
train_config.json # train config
processor.json # processor config (if preprocessor provided)
step_*.safetensors # processor state files (if any)
training_state/
optimizer_param_groups.json # optimizer param groups
optimizer_state.safetensors # optimizer state
@@ -95,10 +100,15 @@ def save_checkpoint(
policy (PreTrainedPolicy): The policy to save.
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
cfg.save_pretrained(pretrained_dir)
if preprocessor is not None:
preprocessor.save_pretrained(pretrained_dir)
if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir)
save_training_state(checkpoint_dir, step, optimizer, scheduler)
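The "save only what was provided" pattern that the diff adds for the optional pre/postprocessor can be sketched generically (the helper name and JSON layout below are assumptions for illustration, not the lerobot API):

```python
import json
import tempfile
from pathlib import Path

def save_optional_components(pretrained_dir: Path, components: dict) -> list[str]:
    # Persist each component that was actually provided; silently skip None,
    # mirroring the `if preprocessor is not None: preprocessor.save_pretrained(...)` checks.
    pretrained_dir.mkdir(parents=True, exist_ok=True)
    saved = []
    for name, state in components.items():
        if state is None:
            continue
        (pretrained_dir / f"{name}.json").write_text(json.dumps(state))
        saved.append(name)
    return saved

with tempfile.TemporaryDirectory() as tmp:
    saved = save_optional_components(
        Path(tmp) / "pretrained_model",
        {"preprocessor": {"steps": ["normalize"]}, "postprocessor": None},
    )
```

This keeps checkpoints backward compatible: older call sites that pass no processors produce exactly the directory layout they did before.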
+86 -15
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import os
from typing import Any
import numpy as np
import rerun as rr
from lerobot.processor import EnvTransition, TransitionKey
def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
"""Initializes the Rerun SDK for visualizing the control loop."""
@@ -28,19 +31,87 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
rr.spawn(memory_limit=memory_limit)
def _is_scalar(x):
return (
isinstance(x, numbers.Real)
or isinstance(x, (np.integer, np.floating))
or (isinstance(x, np.ndarray) and x.ndim == 0)
)
def log_rerun_data(
data: list[dict[str, Any] | EnvTransition] | dict[str, Any] | EnvTransition | None = None,
*,
observation: dict[str, Any] | None = None,
action: dict[str, Any] | None = None,
) -> None:
items = data if isinstance(data, list) else ([data] if data is not None else [])
obs = {} if observation is None else dict(observation)
act = {} if action is None else dict(action)
for idx, item in enumerate(items):
if not isinstance(item, dict):
continue
if any(isinstance(k, TransitionKey) for k in item.keys()):
o = item.get(TransitionKey.OBSERVATION) or {}
a = item.get(TransitionKey.ACTION) or {}
if isinstance(o, dict):
obs.update(o)
if isinstance(a, dict):
act.update(a)
continue
keys = list(item.keys())
has_obs = any(str(k).startswith("observation.") for k in keys)
has_act = any(str(k).startswith("action.") for k in keys)
if has_obs or has_act:
if has_obs:
obs.update(item)
if has_act:
act.update(item)
else:
# No prefixes: assume first is observation, second is action, others are observation
if idx == 0:
obs.update(item)
elif idx == 1:
act.update(item)
else:
obs.update(item)
for k, v in obs.items():
if v is None:
continue
key = k if str(k).startswith("observation.") else f"observation.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalar(float(v)))
elif isinstance(v, np.ndarray):
arr = v
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 1:
for i, vi in enumerate(arr):
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
else:
rr.log(key, rr.Image(arr), static=True)
for k, v in act.items():
if v is None:
continue
key = k if str(k).startswith("action.") else f"action.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalar(float(v)))
elif isinstance(v, np.ndarray):
if v.ndim == 1:
for i, vi in enumerate(v):
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
else:
# Fall back to flattening higher-dimensional arrays
flat = v.flatten()
for i, vi in enumerate(flat):
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
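The dispatch rules in the new `log_rerun_data` (scalar vs. 1-D vector vs. image array, with CHW tensors transposed to HWC before image logging) can be checked in isolation with a NumPy-only sketch that performs the same branching without any `rr.log` calls:

```python
import numbers
import numpy as np

def classify(value):
    # Mirrors the branching in log_rerun_data: scalar, vector, or image.
    if isinstance(value, numbers.Real) or (isinstance(value, np.ndarray) and value.ndim == 0):
        return "scalar", float(value)
    if isinstance(value, np.ndarray):
        arr = value
        # Convert CHW -> HWC when the channel axis comes first
        if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
            arr = np.transpose(arr, (1, 2, 0))
        return ("vector", arr) if arr.ndim == 1 else ("image", arr)
    return "unsupported", value

kind, img = classify(np.zeros((3, 480, 640)))  # CHW image -> transposed to HWC
```

Note the shape heuristic is ambiguous for genuinely tiny images (e.g. a 3x3x3 array); the real function accepts that trade-off in exchange for not threading explicit layout metadata through every call site.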
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
oid sha256:ee0c29d3782aa1cadcf4dc6ed767d9460ff00fff9fc70b460502340b832eefcc
size 5104
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
size 33400
oid sha256:ea76e6711959fd3f905ec2bdc306f488920f00ec99421e4870d05f6205eb323e
size 31672
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c
size 515400
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
size 33400
oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf
size 31672
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
size 992
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
size 47424
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
size 49120
oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
size 47424
@@ -23,7 +23,8 @@ from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.utils.random_utils import set_seed
@@ -37,7 +38,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
train_cfg.validate() # Needed for auto-setting some parameters
dataset = make_dataset(train_cfg)
dataset_stats = dataset.meta.stats
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(train_cfg.policy, dataset_stats=dataset_stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
@@ -49,7 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
)
batch = next(iter(dataloader))
batch = preprocessor(batch)
loss, output_dict = policy.forward(batch)
if output_dict is not None:
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict["loss"] = loss
@@ -96,7 +101,12 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
else:
actions_queue = train_cfg.policy.n_action_repeats
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
actions = {}
for i in range(actions_queue):
unnormalized_action = policy.select_action(obs).contiguous()
action_robot = postprocessor({TransitionKey.ACTION: unnormalized_action}).get(TransitionKey.ACTION)
actions[str(i)] = action_robot
return output_dict, grad_stats, param_stats, actions
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
oid sha256:d640988f2269cf6aa03c8ee17f9d096edace83d837f90025011fafec5bf53c61
size 200
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
size 16904
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
size 164
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
size 35496
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
oid sha256:2212ae7b910d14d723214f5af50985e419f7bd0f4261565ef48b1ef495443d6d
size 200
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
size 16904
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
size 164
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
size 35496
+132
@@ -0,0 +1,132 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
def test_calculate_episode_data_index():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
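The from/to boundaries this test expects can be derived with a short standalone pass over `episode_index` (an illustrative helper, not `calculate_episode_data_index` itself, which works over the HF `Dataset`):

```python
def episode_boundaries(episode_index):
    # Start index of each episode, and the exclusive end index of each.
    froms, tos = [], []
    prev = None
    for i, ep in enumerate(episode_index):
        if ep != prev:
            if prev is not None:
                tos.append(i)  # previous episode ends where a new one begins
            froms.append(i)
            prev = ep
    tos.append(len(episode_index))  # last episode runs to the end
    return froms, tos

froms, tos = episode_boundaries([0, 0, 1, 2, 2, 2])
```

For the episode indices used in the test this yields from = [0, 2, 3] and to = [2, 3, 6], matching the asserted tensors.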
def test_merge_simple_vectors():
g1 = {
"action": {
"dtype": "float32",
"shape": (2,),
"names": ["ee.x", "ee.y"],
}
}
g2 = {
"action": {
"dtype": "float32",
"shape": (2,),
"names": ["ee.y", "ee.z"],
}
}
out = combine_feature_dicts(g1, g2)
assert "action" in out
assert out["action"]["dtype"] == "float32"
# Names merged with preserved order and de-duplication
assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"]
# Shape correctly recomputed from names length
assert out["action"]["shape"] == (3,)
def test_merge_multiple_groups_order_and_dedup():
g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
out = combine_feature_dicts(g1, g2, g3)
assert out["action"]["names"] == ["a", "b", "c", "d"]
assert out["action"]["shape"] == (4,)
def test_non_vector_last_wins_for_images():
# Non-vector (images) with same name should be overwritten by the last image specified
g1 = {
"observation.images.front": {
"dtype": "image",
"shape": (3, 480, 640),
"names": ["channels", "height", "width"],
}
}
g2 = {
"observation.images.front": {
"dtype": "image",
"shape": (3, 720, 1280),
"names": ["channels", "height", "width"],
}
}
out = combine_feature_dicts(g1, g2)
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
assert out["observation.images.front"]["dtype"] == "image"
def test_dtype_mismatch_raises():
g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}}
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
_ = combine_feature_dicts(g1, g2)
def test_non_dict_passthrough_last_wins():
g1 = {"misc": 123}
g2 = {"misc": 456}
out = combine_feature_dicts(g1, g2)
# For non-dict entries the last one wins
assert out["misc"] == 456
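The vector-merge rule these tests exercise reduces to an order-preserving union of the per-axis names, with the shape recomputed from their count; a minimal sketch (a hypothetical helper, not `combine_feature_dicts` itself, which also handles images and non-dict passthrough):

```python
def merge_vector_features(*features):
    # Union the per-axis names in first-seen order, then recompute the shape.
    dtype = features[0]["dtype"]
    names = []
    for feat in features:
        if feat["dtype"] != dtype:
            raise ValueError(f"dtype mismatch: {feat['dtype']} != {dtype}")
        for name in feat["names"]:
            if name not in names:
                names.append(name)
    return {"dtype": dtype, "shape": (len(names),), "names": names}

out = merge_vector_features(
    {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
    {"dtype": "float32", "shape": (2,), "names": ["b", "c"]},
    {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]},
)
```

This reproduces the `test_merge_multiple_groups_order_and_dedup` expectation: names ["a", "b", "c", "d"] and shape (4,).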
-55
@@ -1,55 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
def test_calculate_episode_data_index():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
+6 -104
@@ -26,7 +26,7 @@ from safetensors.torch import load_file
from lerobot import available_policies
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.utils import cycle, dataset_to_policy_features
@@ -39,8 +39,8 @@ from lerobot.policies.factory import (
get_policy_class,
make_policy,
make_policy_config,
make_pre_post_processors,
)
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
@@ -151,6 +151,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# Check that we can make the policy object.
dataset = make_dataset(train_cfg)
preprocessor, _ = make_pre_post_processors(train_cfg.policy, None)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
assert isinstance(policy, PreTrainedPolicy)
@@ -224,6 +225,7 @@ def test_act_backbone_lr():
assert cfg.policy.optimizer_lr_backbone == 0.001
dataset = make_dataset(cfg)
preprocessor, _ = make_pre_post_processors(cfg.policy, None)
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
@@ -263,108 +265,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
def test_normalize(insert_temporal_dim):
"""
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
an exception when the forward pass is called without the stats having been provided.
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
expected.
"""
input_features = {
"observation.image": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 96, 96),
),
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(10,),
),
}
output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
),
}
norm_map = {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
dataset_stats = {
"observation.image": {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
"min": torch.randn(3, 1, 1),
"max": torch.randn(3, 1, 1),
},
"observation.state": {
"mean": torch.randn(10),
"std": torch.randn(10),
"min": torch.randn(10),
"max": torch.randn(10),
},
"action": {
"mean": torch.randn(5),
"std": torch.randn(5),
"min": torch.randn(5),
"max": torch.randn(5),
},
}
bsize = 2
input_batch = {
"observation.image": torch.randn(bsize, 3, 96, 96),
"observation.state": torch.randn(bsize, 10),
}
output_batch = {
"action": torch.randn(bsize, 5),
}
if insert_temporal_dim:
tdim = 4
for key in input_batch:
# [2,3,96,96] -> [2,tdim,3,96,96]
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
for key in output_batch:
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
# test without stats
normalize = Normalize(input_features, norm_map, stats=None)
with pytest.raises(AssertionError):
normalize(input_batch)
# test with stats
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
normalize(input_batch)
# test loading pretrained models
new_normalize = Normalize(input_features, norm_map, stats=None)
new_normalize.load_state_dict(normalize.state_dict())
new_normalize(input_batch)
# test without stats
unnormalize = Unnormalize(output_features, norm_map, stats=None)
with pytest.raises(AssertionError):
unnormalize(output_batch)
# test with stats
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
unnormalize(output_batch)
# test loading pretrained models
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
new_unnormalize.load_state_dict(unnormalize.state_dict())
unnormalize(output_batch)
@pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool):
"""
@@ -464,6 +364,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
is out of date. For example, some PyTorch versions have different randomness, see this PR:
https://github.com/huggingface/lerobot/pull/1127.
NOTE: If the test does not pass, and you have not changed the policy or the dependency versions but have changed your processor, you may need to update the test artifact.
"""
+355
@@ -0,0 +1,355 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ACT policy processor."""
import tempfile
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.processor_act import make_act_pre_post_processors
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
DeviceProcessorStep,
NormalizerProcessorStep,
RenameProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
return transition
def create_default_config():
"""Create a default ACT configuration for testing."""
config = ACTConfig()
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
}
config.output_features = {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
}
config.normalization_mapping = {
FeatureType.STATE: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
}
config.device = "cpu"
return config
def create_default_stats():
"""Create default dataset statistics for testing."""
return {
OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)},
ACTION: {"mean": torch.zeros(4), "std": torch.ones(4)},
}
def test_make_act_processor_basic():
"""Test basic creation of ACT processor."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
# Check processor names
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessorStep)
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[3], DeviceProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
def test_act_processor_normalization():
"""Test that ACT processor correctly normalizes and unnormalizes data."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
# Check that data is normalized and batched
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
assert processed[TransitionKey.ACTION].shape == (1, 4)
# Process action through postprocessor
action_transition = create_transition(action=processed[TransitionKey.ACTION])
postprocessed = postprocessor(action_transition)
# Check that action is unnormalized
assert postprocessed[TransitionKey.ACTION].shape == (1, 4)
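Under the MEAN_STD mapping used by `create_default_config`, the normalize-then-unnormalize round trip this test relies on is plain mean/std arithmetic; a NumPy sketch (illustrative, not the `NormalizerProcessorStep` API, and the stats values here are arbitrary):

```python
import numpy as np

# Hypothetical per-dimension action statistics
mean = np.array([1.0, -2.0, 0.5, 3.0])
std = np.array([2.0, 0.5, 1.0, 4.0])

def normalize(x):
    # What the preprocessor's normalizer step applies per feature
    return (x - mean) / std

def unnormalize(x):
    # What the postprocessor's unnormalizer step inverts
    return x * std + mean

action = np.array([0.3, -1.7, 2.0, 5.0])
roundtrip = unnormalize(normalize(action))
```

With the all-zeros mean and all-ones std in `create_default_stats`, both steps are identities, which is why the test only asserts shapes rather than values.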
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_cuda():
"""Test ACT processor with CUDA device."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
# Check that data is on CUDA
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
assert processed[TransitionKey.ACTION].device.type == "cuda"
# Process through postprocessor
action_transition = create_transition(action=processed[TransitionKey.ACTION])
postprocessed = postprocessor(action_transition)
# Check that action is back on CPU
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_accelerate_scenario():
"""Test ACT processor in simulated Accelerate scenario (data already on GPU)."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU
action = torch.randn(1, 4).to(device)
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
# Check that data stays on same GPU (not moved unnecessarily)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
assert processed[TransitionKey.ACTION].device == device
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_act_processor_multi_gpu():
"""Test ACT processor with multi-GPU setup."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
# Simulate data on different GPU (like in multi-GPU training)
device = torch.device("cuda:1")
observation = {OBS_STATE: torch.randn(1, 7).to(device)}
action = torch.randn(1, 4).to(device)
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
# Check that data stays on cuda:1 (not moved to cuda:0)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
assert processed[TransitionKey.ACTION].device == device
def test_act_processor_without_stats():
"""Test ACT processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_act_pre_post_processors(config, dataset_stats=None)
# Should still create processors, but normalization won't have stats
assert preprocessor is not None
assert postprocessor is not None
# Process should still work (but won't normalize without stats)
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed is not None
def test_act_processor_save_and_load():
"""Test saving and loading ACT processor."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that loaded processor works
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation, action)
processed = loaded_preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
assert processed[TransitionKey.ACTION].shape == (1, 4)
def test_act_processor_device_placement_preservation():
"""Test that ACT processor preserves device placement correctly."""
config = create_default_config()
stats = create_default_stats()
# Test with CPU config
config.device = "cpu"
preprocessor, _ = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Process CPU data
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
assert processed[TransitionKey.ACTION].device.type == "cpu"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_mixed_precision():
"""Test ACT processor with mixed precision (float16)."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
# Modify the device processor to use float16
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessorStep with one that uses float16
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Create test data
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
action = torch.randn(4, dtype=torch.float32)
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
# Check that data is converted to float16
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
assert processed[TransitionKey.ACTION].dtype == torch.float16
def test_act_processor_batch_consistency():
"""Test that ACT processor handles different batch sizes correctly."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test single sample (unbatched)
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
# Test already batched data
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
action_batched = torch.randn(8, 4)
transition_batched = create_transition(observation_batched, action_batched)
processed_batched = preprocessor(transition_batched)
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
assert processed_batched[TransitionKey.ACTION].shape[0] == 8
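The batch-consistency behaviour asserted above (unbatched samples gain a leading dimension, already-batched tensors pass through untouched) comes down to an ndim check; a NumPy sketch, where the expected per-feature ndim is an assumption supplied by the caller rather than inferred:

```python
import numpy as np

def add_batch_dim(x, feature_ndim):
    # Prepend a batch axis only when the input looks like a single unbatched sample.
    return x[None, ...] if x.ndim == feature_ndim else x

state = np.random.randn(7)     # unbatched sample of a (7,) state feature
batch = np.random.randn(8, 7)  # already a batch of 8 samples
```

A single sample becomes shape (1, 7) while the batch of 8 keeps its (8, 7) shape, matching what `AddBatchDimensionProcessorStep` is expected to do in the pipeline.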
