Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-11 14:49:43 +00:00)
Integrate pipeline and add phone teleop (#1681)
* Add normalization processor and related components
  - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization.
  - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks.
  - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports.
  - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity.
  - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Enhance processing architecture with new components
  - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility.
  - Updated `__init__.py` to include `RenameProcessor` in module exports.
  - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling.
  - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.
* chore(docs): add docstring for processor
* fix(test): test factory
* fix(test): policies
* Update tests/processor/test_observation_processor.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
  Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
* chore(test): add suggestion made by Copilot regarding numpy test
* fix(test): import issue
* Refactor normalization components and update tests
  - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity.
  - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`.
  - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes.
  - Enhanced handling of missing statistics in normalization processes.
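The normalize/unnormalize split these commits describe can be sketched roughly as follows. This is an illustrative sketch only, not the lerobot implementation: the class names match the commit messages, but the constructor signature, scalar stats, and key names are invented for the example.

```python
class NormalizerProcessor:
    """Sketch: normalize observation entries using per-key mean/std stats."""

    def __init__(self, stats):
        # stats maps a key like "observation.state" to {"mean": ..., "std": ...}
        self.stats = stats

    def __call__(self, observation):
        out = dict(observation)
        for key, s in self.stats.items():
            if key in out:
                out[key] = (out[key] - s["mean"]) / s["std"]
        return out


class UnnormalizerProcessor:
    """Sketch: invert the transform, typically applied to predicted actions."""

    def __init__(self, stats):
        self.stats = stats

    def __call__(self, action, key="action"):
        s = self.stats[key]
        return action * s["std"] + s["mean"]


stats = {
    "observation.state": {"mean": 1.0, "std": 2.0},
    "action": {"mean": 0.5, "std": 4.0},
}
normalized = NormalizerProcessor(stats)({"observation.state": 3.0})
action = UnnormalizerProcessor(stats)(0.25)
```

The point of the split is that the policy only ever sees normalized inputs and emits normalized outputs; both directions live in pipeline steps rather than inside the model.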
* chore(docstring): Improve docstring for NormalizerProcessor
* feat(device processor): Implement device processor
* chore(batch handling): Enhance processing components with batch conversion utilities
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix(test): linting issue
* chore(output format): improve output format
* chore(type): add typing for multiprocess envs
* feat(overrides): Implement support for loading processors with parameter overrides
  - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
  - Enhanced error handling for invalid override keys and instantiation errors.
  - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
  - Added comprehensive tests to validate the new functionality and ensure backward compatibility.
* chore(normalization): address comments from Copilot
* chore(learner): nit comment from Copilot
* feat(pipeline): Enhance step_through method to support both tuple and dict inputs
* refactor(pipeline): Simplify observation and padding data handling in batch transitions
* Apply suggestions from code review
  Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
  Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions
* fix(ci): temporary fix on dataset deps version
* feat(processors): Introduce processors for various policy types
  - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`.
  - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps.
  - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity.
  - Enhanced test coverage to validate the integration of new processors with existing policy configurations.
* refactor(learner): Remove normalization from cached image features retrieval
  - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls.
  - This change enhances clarity and aligns with the recent updates to policy processors.
* refactor(policies): Remove unnormalization step from action predictions
  - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction.
  - This change improves code clarity and aligns with recent updates to policy processors.
* feat(train): Integrate preprocessor into training pipeline
* refactor(train): Update preprocessor initialization to include dataset statistics
* refactor(policies): Enhance processor creation and add NaN detection hook
* refactor(train): Update memory pinning logic for mps compatibility
* feat: initial commit phone teleop
* ugly delta control
* use quaternion
* Refactor observation preprocessing to use a modular pipeline system
  - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations.
  - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline.
  - Added tests for the new processing components and ensured they match the original functionality.
  - Removed hardcoded logic in favor of a more flexible, composable architecture.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Refactor observation processing and improve modularity
  - Updated `ObservationProcessor` to enhance the modular design for processing observations.
  - Cleaned up imports and improved code readability by removing unnecessary lines and comments.
  - Ensured backward compatibility while integrating new processing components.
  - Added tests to validate the functionality of the updated processing architecture.
* Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability.
* Refactor processing architecture to use RobotProcessor
  - Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity.
  - Introduced ProcessorStepRegistry for better management of processing steps.
  - Updated relevant documentation and tests to reflect the new processing structure.
  - Enhanced the save/load functionality to support the new processor design.
  - Added a model card template for RobotProcessor to facilitate sharing and documentation.
* Add RobotProcessor tutorial to documentation
  - Introduced a new tutorial on using RobotProcessor for preprocessing robot data.
  - Added a section in the table of contents for easy navigation to the new tutorial.
  - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
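The composable-pipeline idea behind the RobotProcessor refactor can be sketched as below. Assumptions: this is a minimal stand-in, not the real lerobot class; the step functions (`rename_pixels`, `to_float`) and the key names are invented for illustration.

```python
class RobotProcessor:
    """Sketch of a composable pipeline: steps are callables applied in order,
    so hardcoded preprocessing can be replaced by small, testable pieces."""

    def __init__(self, steps):
        self.steps = list(steps)

    def __call__(self, data):
        for step in self.steps:
            data = step(data)
        return data


def rename_pixels(obs):
    # Map the raw env key to the canonical observation key (names assumed).
    return {("observation.image" if k == "pixels" else k): v for k, v in obs.items()}


def to_float(obs):
    return {k: float(v) for k, v in obs.items()}


# A backward-compatible wrapper in the spirit of preprocess_observation().
processor = RobotProcessor([rename_pixels, to_float])
out = processor({"pixels": 1, "agent_pos": 2})
```

Because each step is just a callable from dict to dict, steps can be reordered, unit-tested in isolation, and registered by name, which is what the ProcessorStepRegistry mentioned above enables.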
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
  - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
  - Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
  - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
* refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling
  - Introduced constants for observation keys to enhance readability.
  - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly.
  - Updated the environment state and agent position assignments to use the new constants, improving maintainability.
* feat(pipeline): Add hook unregistration functionality and enhance documentation
  - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management.
  - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks.
  - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.
* refactor(pipeline): Clarify hook behavior and improve documentation
  - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
  - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions.
  - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
  - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method.
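The hook semantics described in these commits (hooks observe transitions but never modify them, and can be unregistered) might look roughly like this. A minimal sketch under assumed names; the real RobotProcessor API may differ.

```python
class RobotProcessor:
    """Sketch of hook management: before/after hooks observe transitions,
    their return values are ignored, and they can be unregistered."""

    def __init__(self, steps):
        self.steps = list(steps)
        self.before_hooks = []
        self.after_hooks = []

    def register_before_hook(self, fn):
        self.before_hooks.append(fn)

    def unregister_before_hook(self, fn):
        if fn not in self.before_hooks:
            raise ValueError("hook was never registered")
        self.before_hooks.remove(fn)

    def __call__(self, transition):
        for hook in self.before_hooks:
            hook(transition)  # observation only; return value discarded
        for step in self.steps:
            transition = step(transition)
        for hook in self.after_hooks:
            hook(transition)
        return transition


seen = []
proc = RobotProcessor([lambda t: {**t, "n": t["n"] + 1}])
hook = lambda t: seen.append(t["n"])
proc.register_before_hook(hook)
result = proc({"n": 0})
proc.unregister_before_hook(hook)
```

Discarding hook return values is what makes hooks safe for logging and debugging: a hook cannot silently change what the downstream steps receive.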
* feat(pipeline): Add __repr__ method to RobotProcessor for improved readability
  - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed.
  - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values.
  - Ensured that the representation handles long lists of steps with truncation for better readability.
* chore(pipeline): Move _CFG_NAME alongside other class members
* refactor(pipeline): Utilize get_safe_torch_device for device assignment
  - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling.
  - This change enhances code readability and maintains consistency in device management across the RobotProcessor class.
* refactor(pipeline): Enhance state filename generation and profiling method
  - Updated state filename generation to use the registry name when available, improving clarity in saved files.
  - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling.
  - Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results.
* chore(doc): address `pip install lerobot` command that does not exist yet
* feat(pipeline): Enhance configuration filename handling and state file naming
  - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
  - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved.
  - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files.
  - Updated tests to validate new features, including custom filenames and automatic config detection.
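Several commits above concern loading processors from saved configurations, including an `overrides` parameter for injecting non-serializable objects at load time. A toy sketch of that mechanism, with an invented registry, config layout, and step class (none of this is the actual lerobot loading code):

```python
class DeviceStep:
    """Hypothetical step whose constructor takes a runtime-only parameter."""

    def __init__(self, device="cpu"):
        self.device = device


REGISTRY = {"device_step": DeviceStep}  # assumed step registry


def load_processor(config, overrides=None):
    """Instantiate steps from a saved config dict; `overrides` maps a step
    name to extra kwargs (e.g. objects that cannot be serialized). Unknown
    override keys raise instead of being silently ignored."""
    overrides = overrides or {}
    known = {step["name"] for step in config["steps"]}
    unknown = set(overrides) - known
    if unknown:
        raise KeyError(f"Override keys not found in config: {sorted(unknown)}")
    steps = []
    for step in config["steps"]:
        kwargs = {**step.get("kwargs", {}), **overrides.get(step["name"], {})}
        steps.append(REGISTRY[step["name"]](**kwargs))
    return steps


config = {"steps": [{"name": "device_step", "kwargs": {"device": "cpu"}}]}
steps = load_processor(config, overrides={"device_step": {"device": "cuda:0"}})
```

The override dict is merged after the saved kwargs, so load-time values win over whatever was serialized.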
* refactor(pipeline): Improve state file naming conventions for clarity and uniqueness
  - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
  - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
  - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
* docs(pipeline): Add clarification for repo name sanitization process
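A naming scheme like the one these commits describe (sanitized processor name plus step index) could be sketched as follows. The exact sanitization rules and filename layout here are assumptions, not lerobot's actual ones.

```python
import re


def state_filename(processor_name, step_index, step_registry_name):
    """Sketch: build a unique state filename from the sanitized processor
    name, the step's position in the pipeline, and its registry name, so
    two processors saved in one directory never collide."""
    safe = re.sub(r"[^a-zA-Z0-9_-]+", "_", processor_name).strip("_").lower()
    return f"{safe}_step_{step_index}_{step_registry_name}.safetensors"


f1 = state_filename("My Processor", 0, "normalizer_processor")
f2 = state_filename("My Processor", 1, "normalizer_processor")
```

Including the step index matters because a pipeline can legitimately contain two steps of the same registered type.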
* feat(record): Integrate RobotProcessor into recording loop and update policy handling
  - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities.
  - Updated the logic to reset both policy and processor when provided, ensuring proper state management.
  - Modified action prediction to utilize the processor, improving the overall functionality of the recording process.
  - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.
* feat(migration): Add script for migrating policy models with normalization layers
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models
  - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
  - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
  - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
  - Improved error handling for missing mandatory configuration fields during model instantiation.
* feat(migrate): Add model card generation and saving to migration script
  - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags.
  - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility.
  - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub.
* feat(processor): Introduce ToBatchProcessor for handling observation batching
  - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing.
  - Implemented functionality to add batch dimensions to state and image observations as needed.
  - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types.
  - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data.
* feat(processors): Add ToBatchProcessor to multiple policy processors
  - Integrated ToBatchProcessor into various policy processors to handle observation batching.
  - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality.
  - Ensured consistency across all processor implementations for improved data handling.
* refactor(factory): Remove unused imports and NaN detection hook from processor creation
* feat(batch_processor): Enhance ToBatchProcessor to handle action batching
  - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
  - Implemented separate methods for processing observations and actions, improving code readability.
  - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration
  - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration.
  - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation.
  - Refactored the loading of pretrained processors to utilize the new configuration options.
* refactor(factory): Clean up imports in factory.py
  - Removed unused import of IdentityProcessor to streamline the code.
* feat(migrate): Extend load_model_from_hub to include train configuration
  - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config.
  - Modified main function to handle the additional train configuration when loading models from both the hub and local paths.
  - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy.
* refactor(record): Rename processor parameters and update processing logic
  - Renamed `processor` to `preprocessor` and added a `postprocessor` parameter for clarity.
  - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow.
  - Ensured compatibility with existing functionality while improving code readability.
* feat(batch_processor): Add task field processing to ToBatchProcessor
  - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference.
  - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings.
  - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data.
* feat(normalization): Implement IDENTITY mode for normalization and unnormalization
  - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified.
  - Updated processing logic to check normalization modes and handle missing statistics gracefully.
  - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios.
  - Improved error handling for unsupported normalization modes.
* fix(rebase): remove residual normalization layer
* refactor(diffusion): remove normalization layer from input processing
* Add debug + calib
* cleanup
* Add pipeline
* fix int
* Add record example
* nit
* Add feature contract to pipeline step and pipeline
* Add tests
* Add processor tests
* PR feedback
* incorporate PR feedback
* typo in doc
* oops
* cleaned up steps and integrated pipeline with feature_contract
* refactor steps and robot to pipeline
* cleanup pipeline
* cleanup code further
* make it run
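The ToBatchProcessor behavior described in the commits above (add a leading batch dimension to unbatched states and images, and wrap a bare task string in a list) can be sketched like this. The transition layout and key names are assumptions for the example, not lerobot's exact constants.

```python
import numpy as np


def to_batch(transition):
    """Sketch: add a leading batch dim where it is missing. Unbatched state
    vectors become (1, D), HWC images become (1, H, W, C), and a bare task
    string in complementary data is wrapped in a one-element list."""
    obs = dict(transition["observation"])
    for key, value in obs.items():
        if isinstance(value, np.ndarray):
            if key.endswith("state") and value.ndim == 1:
                obs[key] = value[None, :]
            elif "image" in key and value.ndim == 3:
                obs[key] = value[None, ...]
    action = transition.get("action")
    if isinstance(action, np.ndarray) and action.ndim == 1:
        action = action[None, :]
    comp = dict(transition.get("complementary_data", {}))
    if isinstance(comp.get("task"), str):
        comp["task"] = [comp["task"]]
    return {**transition, "observation": obs, "action": action,
            "complementary_data": comp}


t = to_batch({
    "observation": {"observation.state": np.zeros(7),
                    "observation.image": np.zeros((96, 96, 3))},
    "action": np.zeros(7),
    "complementary_data": {"task": "pick up the cube"},
})
```

Already-batched inputs pass through unchanged, which is why the dimension checks gate every transformation.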
* refactor(normalization): Remove unused state dict transformation methods and streamline imports
  - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process.
  - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging.
  - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation.
  - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.
* refactor(normalization): Clean up imports in normalize_processor.py
* feat(batch_processor): Add feature_contract method to ToBatchProcessor
  - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor.
  - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs.
* fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* feat(tokenizer): Introduce TokenizerProcessor for text tokenization
  - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer.
  - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings.
  - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor.
  - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor.
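The `hotswap_stats` idea mentioned above (swap normalization statistics on an existing pipeline without rebuilding it) could look roughly like this. A sketch under assumptions: the real function lives in normalize_processor.py and its signature may differ; here any step carrying a `stats` attribute is updated.

```python
def hotswap_stats(processor_steps, new_stats):
    """Sketch: replace normalization statistics in place on every step that
    carries a `stats` attribute, leaving the pipeline structure untouched."""
    for step in processor_steps:
        if hasattr(step, "stats"):
            step.stats = dict(new_stats)
    return processor_steps


class FakeNormalizer:
    """Stand-in for a normalizer step that holds per-key stats."""

    def __init__(self, stats):
        self.stats = stats


steps = [FakeNormalizer({"action": {"mean": 0.0, "std": 1.0}})]
hotswap_stats(steps, {"action": {"mean": 5.0, "std": 2.0}})
```

This is useful when fine-tuning on a new dataset: the pipeline keeps its steps and ordering, only the dataset statistics change.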
* feat(language): Enhance language processing in TokenizerProcessor
- Added OBS_LANGUAGE constant to define the observation language key.
- Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature.
- Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization.
- Modified tests to validate the integration of language tokens and attention masks in the observation structure.
* feat(tokenizer): Add padding configuration to TokenizerProcessor
- Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction.
- Updated the `make_pi0_processor` function to include the new padding configuration.
- Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios.
* feat(processor): Add state management methods to Pi0NewLineProcessor
* feat(normalization): Track normalization and unnormalization info in complementary data
- Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes.
- Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions.
- Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys.
* feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs
- Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations.
- Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization.
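The newline behavior described above (Pi0NewLineProcessor, and SmolVLANewLineProcessor below) reduces to a tiny fix-up applied to task strings before tokenization. A hedged sketch, with a hypothetical function name:

```python
# Sketch of the "new line" task fix-up: the pi0/SmolVLA tokenization path
# expects each task string to end with a newline. Hypothetical mini-version.

def ensure_trailing_newline(task):
    # Accept a single task string or a batch (list) of task strings.
    if isinstance(task, str):
        task = [task]
    return [t if t.endswith("\n") else t + "\n" for t in task]
```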
* feat(processors): Integrate RenameProcessor into various processor configurations
- Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency.
- Updated the input steps to ensure compatibility with the new RenameProcessor integration.
* Do some todos and cleanup
* change feature_contract to dataset_features
* use one method for conversion pipeline output to add_frame dict and use base processors where possible
* Add back in and use record_loop
* update todo
* rename to_dataset_frame
* feat(smolvla): Refactor language processing and introduce new line processor (#1658)
- Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant.
- Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility.
- Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling.
* feat(processors): Integrate DeviceProcessor into multiple processor configurations
- Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines.
- Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations.
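The DeviceProcessor integration above (together with the float-dtype commit that follows) implies one rule worth spelling out: floating-point tensors are cast to the requested float dtype, while non-float fields such as index and task_index keep their dtype. A sketch using numpy as a stand-in for torch; names are illustrative, not lerobot's API:

```python
import numpy as np

# Sketch of the float-dtype rule: floating arrays are cast to the requested
# float dtype, integer arrays and non-array fields pass through unchanged.
# Hypothetical mini-version; numpy stands in for torch here.

def cast_floats(batch: dict, float_dtype=np.float32) -> dict:
    out = {}
    for key, value in batch.items():
        if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating):
            out[key] = value.astype(float_dtype)
        else:
            out[key] = value  # e.g. index/task_index keep int64; strings untouched
    return out
```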
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix
* fix reference frame
* refactor(pipeline): Remove to() method for device management
- Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices.
- Removed associated unit tests that validated the functionality of the to() method across various scenarios.
- Streamlined the pipeline code by focusing on other device management strategies.
* feat(processor): Enhance DeviceProcessor with float dtype conversion
- Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types.
- Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype.
- Refactored tensor processing logic to streamline device movement and dtype conversion.
- Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios.
* update data visualization
* update teleop example
* fix record bugs
* Add replay
* Not code
* feature(pipeline): port tokenizer pipeline for VLA (#1645)
* feature(policies): add device processor (#1659)
* feat(policies): Add new line processors and update module exports
* feat(processor): Enhance batch and device processors to handle index and task_index fields
- Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors.
- Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged.
- Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation.
* Add eval script
* fix `q_curr` in InverseKinematicsEEToJoints to the IK solution
* feat(processors): Introduce processors for various policy types
- Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`.
- Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps.
- Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity.
- Enhanced test coverage to validate the integration of new processors with existing policy configurations.
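The ToBatchProcessor commits above describe three batching rules: arrays gain a leading batch dimension, a bare task string becomes a one-element list, and 0-d index/task_index values become 1-d. A minimal sketch under those assumptions; numpy stands in for torch and the real processor's API differs:

```python
import numpy as np

# Sketch of ToBatchProcessor-style batching. Hypothetical mini-version.

def to_batch(observation: dict) -> dict:
    out = {}
    for key, value in observation.items():
        if key == "task" and isinstance(value, str):
            out[key] = [value]           # wrap bare task string in a batch list
        elif isinstance(value, np.ndarray) and value.ndim == 0:
            out[key] = value[None]       # unsqueeze 0-d index/task_index to 1-d
        elif isinstance(value, np.ndarray):
            out[key] = value[None, ...]  # add leading batch dimension
        else:
            out[key] = value
    return out
```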
* refactor(learner): Remove normalization from cached image features retrieval
- Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls.
- This change enhances clarity and aligns with the recent updates to policy processors.
* refactor(policies): Remove unnormalization step from action predictions
- Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction.
- This change improves code clarity and aligns with recent updates to policy processors.
* feat(train): Integrate preprocessor into training pipeline
* refactor(train): Update preprocessor initialization to include dataset statistics
* refactor(policies): Enhance processor creation and add NaN detection hook
* feat(record): Integrate RobotProcessor into recording loop and update policy handling
- Added support for RobotProcessor in the record_loop function to enhance data processing capabilities.
- Updated the logic to reset both policy and processor when provided, ensuring proper state management.
- Modified action prediction to utilize the processor, improving the overall functionality of the recording process.
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities.
* feat(migration): Add script for migrating policy models with normalization layers
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models
- Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference.
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture.
- Refined the extraction and removal of normalization statistics and layers, streamlining the migration process.
- Improved error handling for missing mandatory configuration fields during model instantiation.
* feat(migrate): Add model card generation and saving to migration script
- Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags.
- Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility.
- Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub.
* feat(processor): Introduce ToBatchProcessor for handling observation batching
- Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing.
- Implemented functionality to add batch dimensions to state and image observations as needed.
- Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types.
- Ensured compatibility with existing transition keys and maintained the integrity of non-observation data.
* feat(processors): Add ToBatchProcessor to multiple policy processors
- Integrated ToBatchProcessor into various policy processors to handle observation batching.
- Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality.
- Ensured consistency across all processor implementations for improved data handling.
* refactor(factory): Remove unused imports and NaN detection hook from processor creation
* feat(batch_processor): Enhance ToBatchProcessor to handle action batching
- Updated ToBatchProcessor to add batch dimensions to actions in addition to observations.
- Implemented separate methods for processing observations and actions, improving code readability.
- Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types.
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* refactor(processors): Standardize processor naming conventions
- Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format.
- Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme.
- Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability.
* refactor(factory): Update processor configuration and type hints
- Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety.
- Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility.
- Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations.
- Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling.
* Fix eval and android gripper
* add some tests
* refactor(factory, pi0fast): Update processor function names and parameters
- Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency.
- Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions.
* fix(train.py) push postprocessor with preprocessor
- Add preprocessor policy overrides for device and rename_map
- Add rename_map to DatasetRecordConfig (record.py)
* Cleanup pr
* fix more git diff pr issues
* add path as type in save_pretrained
* small nit
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* rename test file
* fix: make dataset_features/feature_contract optional
* fix tests
* Incorporate PR feedback
* clean up record.py
* add ascii art, fix normal record
* remove merge issues
* fix merge
* remove features
* Add feedback PR
* fix last 4 tests
* remove features check
* rename to transform_features
* add transform_features
* fix lekiwi eval and update eval api example
---------
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
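Taken together, the preprocessor/postprocessor split described in this PR turns policy inference into a three-stage pipe: observations flow through the preprocessor, the policy operates in normalized space, and the postprocessor maps actions back out. A schematic sketch, not lerobot code (names are illustrative):

```python
# Schematic of the preprocessor -> policy -> postprocessor inference flow.
# All names here are hypothetical; the real pipeline lives in lerobot.record.

def run_inference(observation, preprocessor, policy, postprocessor):
    batch = preprocessor(observation)     # normalize, batch, move to device, ...
    action = policy.select_action(batch)  # raw (normalized) model output
    return postprocessor(action)          # unnormalize, move back to CPU, ...
```

The diffs below show this same pair being created with make_processor and handed to record_loop.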
@@ -519,11 +519,14 @@ from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.record import record_loop
from lerobot.policies.factory import make_processor

NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"

# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
@@ -535,7 +538,7 @@ robot_config = SO100FollowerConfig(
robot = SO100Follower(robot_config)

# Initialize the policy
policy = ACTPolicy.from_pretrained("<hf_username>/<my_policy_repo_id>")
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)

# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
@@ -544,7 +547,7 @@ dataset_features = {**action_features, **obs_features}

# Create the dataset
dataset = LeRobotDataset.create(
    repo_id="<hf_username>/eval_<dataset_repo_id>",
    repo_id=HF_DATASET_ID,
    fps=FPS,
    features=dataset_features,
    robot_type=robot.name,
@@ -559,6 +562,12 @@ _init_rerun(session_name="recording")
# Connect the robot
robot.connect()

preprocessor, postprocessor = make_processor(
    policy_cfg=policy,
    pretrained_path=HF_MODEL_ID,
    dataset_stats=dataset.meta.stats,
)

for episode_idx in range(NUM_EPISODES):
    log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")

@@ -568,6 +577,8 @@ for episode_idx in range(NUM_EPISODES):
        events=events,
        fps=FPS,
        policy=policy,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        dataset=dataset,
        control_time_s=EPISODE_TIME_SEC,
        single_task=TASK_DESCRIPTION,
@@ -1,6 +1,7 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_processor
from lerobot.record import record_loop
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.utils.control_utils import init_keyboard_listener
@@ -11,12 +12,14 @@ NUM_EPISODES = 2
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"

# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
robot = LeKiwiClient(robot_config)

policy = ACTPolicy.from_pretrained("<hf_username>/<policy_repo_id>")
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)

# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
@@ -25,7 +28,7 @@ dataset_features = {**action_features, **obs_features}

# Create the dataset
dataset = LeRobotDataset.create(
    repo_id="<hf_username>/<eval_dataset_repo_id>",
    repo_id=HF_DATASET_ID,
    fps=FPS,
    features=dataset_features,
    robot_type=robot.name,
@@ -43,6 +46,12 @@ listener, events = init_keyboard_listener()
if not robot.is_connected:
    raise ValueError("Robot is not connected!")

preprocessor, postprocessor = make_processor(
    policy_cfg=policy,
    pretrained_path=HF_MODEL_ID,
    dataset_stats=dataset.meta.stats,
)

recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
    log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
@@ -53,6 +62,8 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
        events=events,
        fps=FPS,
        policy=policy,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        dataset=dataset,
        control_time_s=EPISODE_TIME_SEC,
        single_task=TASK_DESCRIPTION,
@@ -38,7 +38,7 @@ while True:
    keyboard_keys = keyboard.get_action()
    base_action = robot._from_keyboard_to_base_action(keyboard_keys)

    log_rerun_data(observation, {**arm_action, **base_action})
    log_rerun_data(observation=observation, action={**arm_action, **base_action})

    action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
@@ -0,0 +1,158 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import merge_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_processor
from lerobot.processor.converters import (
    to_output_robot_action,
    to_transition_robot_observation,
)
from lerobot.processor.pipeline import RobotProcessor
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
    AddRobotObservationAsComplimentaryData,
    ForwardKinematicsJointsToEE,
    InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun

NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"

# Initialize the robot configuration (with joints reported in degrees)
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471",
    id="my_awesome_follower_arm",
    cameras=camera_config,
    use_degrees=True,
)

# Initialize the robot
robot = SO100Follower(robot_config)

# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
    urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
    target_frame_name="gripper_frame_link",
    joint_names=list(robot.bus.motors.keys()),
)

# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
    steps=[
        AddRobotObservationAsComplimentaryData(robot=robot),
        InverseKinematicsEEToJoints(
            kinematics=kinematics_solver,
            motor_names=list(robot.bus.motors.keys()),
            initial_guess_current_joints=True,
        ),
    ],
    to_transition=lambda tr: tr,
    to_output=to_output_robot_action,
)

# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessor(
    steps=[
        ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
    ],
    to_transition=to_transition_robot_observation,
    to_output=lambda tr: tr,
)

# Build dataset action and gripper features
action_ee_and_gripper = aggregate_pipeline_dataset_features(
    pipeline=robot_ee_to_joints,
    initial_features={},
    use_videos=True,
    patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
)  # Get all ee action features + gripper pos action features

# Build dataset observation features
obs_ee = aggregate_pipeline_dataset_features(
    pipeline=robot_joints_to_ee_pose,
    initial_features=robot.observation_features,
    use_videos=True,
    patterns=["observation.state.ee"],
)  # Get all ee observation features

dataset_features = merge_features(obs_ee, action_ee_and_gripper)

print("All dataset features: ", dataset_features)

# Create the dataset
dataset = LeRobotDataset.create(
    repo_id=HF_DATASET_ID,
    fps=FPS,
    features=dataset_features,
    robot_type=robot.name,
    use_videos=True,
    image_writer_threads=4,
)

# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")

# Connect the robot
robot.connect()

policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
preprocessor, postprocessor = make_processor(
    policy_cfg=policy,
    pretrained_path=HF_MODEL_ID,
    dataset_stats=dataset.meta.stats,
)

for episode_idx in range(NUM_EPISODES):
    log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")

    record_loop(
        robot=robot,
        events=events,
        fps=FPS,
        policy=policy,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        dataset=dataset,
        control_time_s=EPISODE_TIME_SEC,
        single_task=TASK_DESCRIPTION,
        display_data=True,
        robot_action_processor=robot_ee_to_joints,
        robot_observation_processor=robot_joints_to_ee_pose,
    )
    dataset.save_episode()

# Clean up
log_say("Stop recording")
robot.disconnect()
dataset.push_to_hub()
@@ -0,0 +1,215 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.datasets.utils import merge_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.converters import (
    to_output_robot_action,
    to_transition_robot_observation,
    to_transition_teleop_action,
)
from lerobot.processor.pipeline import RobotProcessor
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
    AddRobotObservationAsComplimentaryData,
    EEBoundsAndSafety,
    EEReferenceAndDelta,
    ForwardKinematicsJointsToEE,
    GripperVelocityToJoint,
    InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone import Phone
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun

NUM_EPISODES = 10
FPS = 30
EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 30
TASK_DESCRIPTION = "My task description"
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"

# Build the robot and teleoperator configurations
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471",
    id="my_awesome_follower_arm",
    cameras=camera_config,
    use_degrees=True,
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS)  # or PhoneOS.ANDROID

# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)

# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
    urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
    target_frame_name="gripper_frame_link",
    joint_names=list(robot.bus.motors.keys()),
)

# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor(
    steps=[
        MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
        AddRobotObservationAsComplimentaryData(robot=robot),
        EEReferenceAndDelta(
            kinematics=kinematics_solver,
            end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
            motor_names=list(robot.bus.motors.keys()),
        ),
        EEBoundsAndSafety(
            end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
            max_ee_step_m=0.20,
            max_ee_twist_step_rad=0.50,
        ),
    ],
    to_transition=to_transition_teleop_action,
    to_output=lambda tr: tr,
)

# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
    steps=[
        InverseKinematicsEEToJoints(
            kinematics=kinematics_solver,
            motor_names=list(robot.bus.motors.keys()),
            initial_guess_current_joints=True,
        ),
        GripperVelocityToJoint(
            motor_names=list(robot.bus.motors.keys()),
            speed_factor=20.0,
        ),
    ],
    to_transition=lambda tr: tr,
    to_output=to_output_robot_action,
)

# Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessor(
    steps=[
        ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
    ],
    to_transition=to_transition_robot_observation,
    to_output=lambda tr: tr,
)

# Build dataset ee action features
action_ee = aggregate_pipeline_dataset_features(
    pipeline=phone_to_robot_ee_pose,
    initial_features=phone.action_features,
    use_videos=True,
    patterns=["action.ee"],
)

# Get gripper pos action features
gripper = aggregate_pipeline_dataset_features(
    pipeline=robot_ee_to_joints,
    initial_features={},
    use_videos=True,
    patterns=["action.gripper.pos", "observation.state.gripper.pos"],
)

# Build dataset ee observation features
observation_ee = aggregate_pipeline_dataset_features(
    pipeline=robot_joints_to_ee_pose,
    initial_features=robot.observation_features,
    use_videos=True,
    patterns=["observation.state.ee"],
)

dataset_features = merge_features(action_ee, gripper, observation_ee)

print("All dataset features: ", dataset_features)

# Create the dataset
dataset = LeRobotDataset.create(
    repo_id=HF_REPO_ID,
    fps=FPS,
    features=dataset_features,
    robot_type=robot.name,
    use_videos=True,
    image_writer_threads=4,
)

# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")

# Connect the robot and teleoperator
robot.connect()
phone.connect()

episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
    log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")

    record_loop(
        robot=robot,
        events=events,
        fps=FPS,
        teleop=phone,
        dataset=dataset,
        control_time_s=EPISODE_TIME_SEC,
        single_task=TASK_DESCRIPTION,
        display_data=True,
        teleop_action_processor=phone_to_robot_ee_pose,
        robot_action_processor=robot_ee_to_joints,
        robot_observation_processor=robot_joints_to_ee_pose,
    )

    # Reset the environment if not stopping or re-recording
    if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
        log_say("Reset the environment")
        record_loop(
            robot=robot,
            events=events,
            fps=FPS,
            teleop=phone,
            control_time_s=RESET_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=phone_to_robot_ee_pose,
            robot_action_processor=robot_ee_to_joints,
            robot_observation_processor=robot_joints_to_ee_pose,
        )

    if events["rerecord_episode"]:
        log_say("Re-recording episode")
        events["rerecord_episode"] = False
        events["exit_early"] = False
        dataset.clear_episode_buffer()
        continue

    dataset.save_episode()
    episode_idx += 1

# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
dataset.push_to_hub()
@@ -0,0 +1,106 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.converters import to_output_robot_action
from lerobot.processor.pipeline import RobotProcessor
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
    AddRobotObservationAsComplimentaryData,
    InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say

EPISODE_IDX = 0
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"

robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
robot = SO100Follower(robot_config)
robot.connect()

dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
actions = dataset.hf_dataset.select_columns("action")

# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
    urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
    target_frame_name="gripper_frame_link",
    joint_names=list(robot.bus.motors.keys()),
)


# This function converts an action row from the dataset into a transition for the pipeline
def action_to_transition(action: dict):
    act = {}

    # EE pose
    for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"):
        if k in action:
            act[f"action.{k}"] = float(action[k])

    # Gripper: the dataset stores an absolute position
    if "gripper.pos" in action:
        act["action.gripper.pos"] = float(action["gripper.pos"])

    return {
        "observation": None,
        "action": act,
        "reward": None,
        "done": False,
        "truncated": False,
        "info": {},
        "complementary_data": {},
    }


# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
    steps=[
        AddRobotObservationAsComplimentaryData(robot=robot),
        InverseKinematicsEEToJoints(
            kinematics=kinematics_solver,
            motor_names=list(robot.bus.motors.keys()),
            initial_guess_current_joints=False,  # Because replay is open loop
        ),
    ],
    to_transition=action_to_transition,
    to_output=to_output_robot_action,
)

robot_ee_to_joints.reset()

log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(dataset.num_frames):
    t0 = time.perf_counter()

    ee_action = {
        name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
    }

    joint_action = robot_ee_to_joints(ee_action)
    robot.send_action(joint_action)

    busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))

robot.disconnect()
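The `action_to_transition` helper in the replay script above is plain Python, so its mapping can be checked in isolation. The sketch below copies the function from the diff and feeds it a hypothetical dataset row (the sample values are made up for illustration):

```python
# Standalone copy of action_to_transition from the replay script above,
# exercised without any robot hardware. Sample row values are hypothetical.
def action_to_transition(action: dict) -> dict:
    act = {}

    # EE pose components become "action.<name>" keys
    for k in ("ee.x", "ee.y", "ee.z", "ee.wx", "ee.wy", "ee.wz"):
        if k in action:
            act[f"action.{k}"] = float(action[k])

    # Gripper is stored as an absolute position
    if "gripper.pos" in action:
        act["action.gripper.pos"] = float(action["gripper.pos"])

    return {
        "observation": None,
        "action": act,
        "reward": None,
        "done": False,
        "truncated": False,
        "info": {},
        "complementary_data": {},
    }


row = {"ee.x": 0.1, "ee.y": -0.2, "ee.z": 0.3, "gripper.pos": 45.0}
tr = action_to_transition(row)
print(tr["action"])
# {'action.ee.x': 0.1, 'action.ee.y': -0.2, 'action.ee.z': 0.3, 'action.gripper.pos': 45.0}
```

Missing pose components are simply skipped, which is why the loop guards each key with `if k in action` rather than indexing directly.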
@@ -0,0 +1,109 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessor
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
    AddRobotObservationAsComplimentaryData,
    EEBoundsAndSafety,
    EEReferenceAndDelta,
    GripperVelocityToJoint,
    InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone import Phone
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction

# Build the robot and teleoperator configurations
robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS)  # or PhoneOS.ANDROID

# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop_device = Phone(teleop_config)

# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
    urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
    target_frame_name="gripper_frame_link",
    joint_names=list(robot.bus.motors.keys()),
)

# Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor(
    steps=[
        MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
        AddRobotObservationAsComplimentaryData(robot=robot),
        EEReferenceAndDelta(
            kinematics=kinematics_solver,
            end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
            motor_names=list(robot.bus.motors.keys()),
        ),
        EEBoundsAndSafety(
            end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
            max_ee_step_m=0.10,
            max_ee_twist_step_rad=0.50,
        ),
    ],
    to_transition=to_transition_teleop_action,
    to_output=lambda tr: tr,
)

# Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor(
    steps=[
        InverseKinematicsEEToJoints(
            kinematics=kinematics_solver,
            motor_names=list(robot.bus.motors.keys()),
        ),
        GripperVelocityToJoint(
            motor_names=list(robot.bus.motors.keys()),
            speed_factor=20.0,
        ),
    ],
    to_transition=lambda tr: tr,
    to_output=to_output_robot_action,
)

robot.connect()
teleop_device.connect()

print("Starting teleop loop. Move your phone to teleoperate the robot.")
while True:
    # Get teleop observation
    phone_obs = teleop_device.get_action()
    if not phone_obs:
        time.sleep(0.01)
        continue

    # Phone to EE pose transition
    ee_transition = phone_to_robot_ee_pose(phone_obs)

    # EE pose to Joints transition
    joint_action = robot_ee_to_joints(ee_transition)

    if joint_action:
        robot.send_action(joint_action)

    time.sleep(0.01)
@@ -111,6 +111,7 @@ intelrealsense = [
     "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
     "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
 ]
+phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
 # stretch = [
 #     "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
 #     "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
@@ -152,7 +153,8 @@ all = [
     "lerobot[video_benchmark]",
     "lerobot[aloha]",
     "lerobot[pusht]",
-    "lerobot[xarm]"
+    "lerobot[xarm]",
+    "lerobot[phone]",
 ]

 [project.scripts]
@@ -0,0 +1,94 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections.abc import Sequence
from typing import Any

from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor.pipeline import RobotProcessor


def aggregate_pipeline_dataset_features(
    pipeline: RobotProcessor,
    initial_features: dict[str, Any],
    *,
    use_videos: bool = True,
    patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
    """
    Aggregate the pipeline's features and return a features dict ready for the dataset,
    filtered to only those keys matching any of the given patterns (action/state only).

    - `initial_features`: raw camera specs, e.g. {"front": (h, w, c), ...}
    - `use_videos`: whether to treat image features as video streams
    - `patterns`: regexes to filter action & state features; images are included
      whenever use_videos=True, regardless of patterns.
    """
    # Gather everything the pipeline's features specify, seeded with hardware cams:
    all_features = pipeline.transform_features(initial_features)

    # Helper to decide which action/state keys survive the `patterns` filter:
    def keep(key: str) -> bool:
        if patterns is None:
            return True
        return any(re.search(pat, key) for pat in patterns)

    # Start with hardware dict, injecting initial cameras if videos are ON:
    hw: dict[str, dict[str, Any]] = {}
    if use_videos:
        cams = {
            name: shape
            for name, shape in initial_features.items()
            if isinstance(shape, tuple) and len(shape) == 3
        }
        if cams:
            hw["observation"] = dict(cams)

    # Go over every feature from the pipeline and merge:
    for full_key, ty in all_features.items():
        if full_key.startswith("action."):
            # action.<feat>
            if not keep(full_key):
                continue
            name = full_key[len("action.") :]
            hw.setdefault("action", {})[name] = ty

        elif full_key.startswith("observation.state."):
            # observation.state.<feat>
            if not keep(full_key):
                continue
            name = full_key[len("observation.state.") :]
            hw.setdefault("observation", {})[name] = ty

        elif full_key.startswith("observation.images."):
            # observation.images.<cam>
            # Images obey ONLY the use_videos flag, not patterns.
            if not use_videos:
                continue
            name = full_key[len("observation.images.") :]
            hw.setdefault("observation", {})[name] = ty

        else:
            # Anything else (e.g. policy-only features) is ignored here.
            continue

    out: dict[str, dict] = {}
    if "action" in hw:
        out.update(hw_to_dataset_features(hw["action"], "action", use_videos))
    if "observation" in hw:
        out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))

    return out
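The `patterns` filter above uses `re.search`, not exact matching, so a pattern like `"action.ee"` also keeps sub-keys such as `action.ee.x`. A standalone sketch of that filter (the example keys are hypothetical feature names, not taken from a real robot):

```python
import re


# Standalone sketch of the `keep` filter used by aggregate_pipeline_dataset_features:
# a key survives if any pattern matches anywhere in it (re.search semantics).
def keep(key: str, patterns=None) -> bool:
    if patterns is None:
        return True  # no filter: keep everything
    return any(re.search(pat, key) for pat in patterns)


keys = [
    "action.ee.x",
    "action.gripper.pos",
    "observation.state.ee.y",
    "observation.state.shoulder_pan.pos",
]
print([k for k in keys if keep(k, patterns=["action.ee", "observation.state.ee"])])
# ['action.ee.x', 'observation.state.ee.y']
```

Note that the patterns are raw regexes, so the literal `.` in `"action.ee"` is a wildcard; for these key names that is harmless, but an exact prefix match would require escaping.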
@@ -470,6 +470,50 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
     return policy_features


+def merge_features(*dicts: dict) -> dict:
+    """
+    Merge LeRobot grouped feature dicts.
+
+    - For 1D numeric specs (dtype not image/video/string) with "names": merge the names and recompute the shape.
+    - For all others (e.g. observation.images.*): the last definition wins, so duplicate entries should be identical.
+    """
+    out: dict = {}
+    for d in dicts:
+        for key, value in d.items():
+            if not isinstance(value, dict):
+                out[key] = value
+                continue
+
+            dtype = value.get("dtype")
+            shape = value.get("shape")
+            is_vector = (
+                dtype not in ("image", "video", "string")
+                and isinstance(shape, tuple)
+                and len(shape) == 1
+                and "names" in value
+            )
+
+            if is_vector:
+                # Initialize or retrieve the accumulating dict for this feature key
+                target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
+                # Ensure consistent data types across merged entries
+                if "dtype" in target and dtype != target["dtype"]:
+                    raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
+
+                # Merge feature names: append only new ones to preserve order without duplicates
+                seen = set(target["names"])
+                for n in value["names"]:
+                    if n not in seen:
+                        target["names"].append(n)
+                        seen.add(n)
+                # Recompute the shape to reflect the updated number of features
+                target["shape"] = (len(target["names"]),)
+            else:
+                # For images/videos and non-1D entries: override with the latest definition
+                out[key] = value
+    return out
+
+
 def create_empty_dataset_info(
     codebase_version: str,
     fps: int,
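`merge_features` is pure Python, so its name-merging behavior can be shown standalone. The sketch below copies the function from the hunk above and merges two hypothetical 1D feature specs that share the `observation.state` key:

```python
# Standalone copy of merge_features from the hunk above, for illustration.
def merge_features(*dicts: dict) -> dict:
    out: dict = {}
    for d in dicts:
        for key, value in d.items():
            if not isinstance(value, dict):
                out[key] = value
                continue

            dtype = value.get("dtype")
            shape = value.get("shape")
            is_vector = (
                dtype not in ("image", "video", "string")
                and isinstance(shape, tuple)
                and len(shape) == 1
                and "names" in value
            )

            if is_vector:
                # Accumulate names under the shared key, then recompute the shape
                target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
                if "dtype" in target and dtype != target["dtype"]:
                    raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
                seen = set(target["names"])
                for n in value["names"]:
                    if n not in seen:
                        target["names"].append(n)
                        seen.add(n)
                target["shape"] = (len(target["names"]),)
            else:
                # Images/videos and non-1D entries: last definition wins
                out[key] = value
    return out


obs = {"observation.state": {"dtype": "float32", "shape": (2,), "names": ["ee.x", "ee.y"]}}
grip = {"observation.state": {"dtype": "float32", "shape": (1,), "names": ["gripper.pos"]}}
merged = merge_features(obs, grip)
print(merged["observation.state"])
# {'dtype': 'float32', 'names': ['ee.x', 'ee.y', 'gripper.pos'], 'shape': (3,)}
```

Because names are deduplicated, merging the same dict twice is a no-op, which is what lets the recording scripts combine overlapping pipeline feature dicts safely.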
@@ -65,8 +65,8 @@ class Pi0NewLineProcessor(ProcessorStep):

         return transition

-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
-        """Add tokenized task features to the feature contract."""
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+        """Add tokenized task features to the features."""
         return features

     def state_dict(self) -> dict[str, torch.Tensor]:

@@ -88,8 +88,8 @@ class SmolVLANewLineProcessor(ProcessorStep):

         return transition

-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
-        """Add tokenized task features to the feature contract."""
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+        """Adds nothing to the features."""
         return features

     def state_dict(self) -> dict[str, torch.Tensor]:

@@ -17,6 +17,7 @@ from typing import Any
 import torch
 from torch import Tensor

+from lerobot.configs.types import PolicyFeature
 from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
 from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey

@@ -134,6 +135,5 @@ class ToBatchProcessor:
         """Reset processor state (no-op for this processor)."""
         pass

-    def feature_contract(self, features: dict[str, Any]) -> dict[str, Any]:
-        """Return features (no-op for this processor)."""
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features
@@ -0,0 +1,225 @@ (new file: all lines added)
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Iterable, Sequence
from copy import deepcopy
from typing import Any

import numpy as np
import torch
from scipy.spatial.transform import Rotation

from .pipeline import EnvTransition, TransitionKey


def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]):
    if isinstance(x, torch.Tensor):
        return x
    if isinstance(x, np.ndarray):
        # Keep images (uint8 HWC) and python objects as-is
        if x.dtype == np.uint8 or x.dtype == np.object_:
            return x
        # Scalars/arrays to float32 tensor
        return torch.as_tensor(x, dtype=torch.float32)
    # Anything else to float32 tensor
    return torch.as_tensor(x, dtype=torch.float32)


def _from_tensor(x: Any):
    if isinstance(x, torch.Tensor):
        return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
    return x


def _is_image(arr: Any) -> bool:
    return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3


def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    state, images = {}, {}
    for k, v in obs.items():
        if _is_image(v):
            images[k] = v
        else:
            state[k] = v
    return state, images


def make_obs_act_transition(
    *, obs: dict[str, Any] | None = None, act: dict[str, Any] | None = None
) -> EnvTransition:
    return {
        TransitionKey.OBSERVATION: {} if obs is None else obs,
        TransitionKey.ACTION: {} if act is None else act,
        TransitionKey.INFO: {},
        TransitionKey.COMPLEMENTARY_DATA: {},
        TransitionKey.REWARD: None,
        TransitionKey.DONE: None,
        TransitionKey.TRUNCATED: None,
    }


def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
    """Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey."""
    act_dict: dict[str, Any] = {}
    for k, v in action.items():
        # Check if the value is a type that should not be converted to a tensor.
        if isinstance(v, (Rotation, dict)):
            act_dict[f"action.{k}"] = v
            continue

        arr = np.array(v) if np.isscalar(v) else v
        act_dict[f"action.{k}"] = _to_tensor(arr)

    return make_obs_act_transition(act=act_dict)


# TODO(Adil, Pepijn): Over time we can maybe add these converters to pipeline.py itself
def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransition:
    """Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey."""
    state, images = _split_obs_to_state_and_images(observation)

    obs_dict: dict[str, Any] = {}
    for k, v in state.items():
        arr = np.array(v) if np.isscalar(v) else v
        obs_dict[f"observation.state.{k}"] = _to_tensor(arr)

    for cam, img in images.items():
        obs_dict[f"observation.images.{cam}"] = img

    return make_obs_act_transition(obs=obs_dict)


def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
    """Convert an EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions."""
    out: dict[str, Any] = {}
    action_dict = transition.get(TransitionKey.ACTION) or {}

    for k, v in action_dict.items():
        if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")):
            out_key = k[len("action.") :]  # Strip the 'action.' prefix.
            out[out_key] = float(v)

    return out


def to_dataset_frame(
    transitions_or_transition: EnvTransition | Iterable[EnvTransition], features: dict[str, dict]
) -> dict[str, Any]:
    """Convert a single EnvTransition or an iterable of them into a flat,
    dataset-friendly dictionary for training or evaluation, according to
    the provided `features` spec.

    Args:
        transitions_or_transition: Either a single EnvTransition dict
            or an iterable of them (which will be merged).
        features (dict[str, dict]): A feature specification dictionary:
            - 'action': dict with 'names': list of action feature names
            - 'observation.state': dict with 'names': list of state feature names
            - keys starting with 'observation.images.' are passed through

    Returns:
        batch (dict[str, Any]): Flat dictionary containing:
            - numpy arrays for "observation.state" and "action"
            - any image tensors defined in features
            - next.{reward,done,truncated}
            - info dict
            - *_is_pad flags and task from complementary_data
    """
    action_names = features.get("action", {}).get("names", [])
    obs_state_names = features.get("observation.state", {}).get("names", [])
    image_keys = [k for k in features if k.startswith("observation.images.")]

    def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
        out = deepcopy(base)
        for key in (
            TransitionKey.OBSERVATION,
            TransitionKey.ACTION,
            TransitionKey.INFO,
            TransitionKey.COMPLEMENTARY_DATA,
        ):
            if other.get(key):
                out.setdefault(key, {}).update(deepcopy(other[key]))
        for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
            if k in other:
                out[k] = other[k]
        return out

    def _ensure_transition(obj) -> EnvTransition:
        # single transition
        if isinstance(obj, dict) and any(isinstance(k, TransitionKey) for k in obj):
            return obj
        # iterable of transitions
        if isinstance(obj, Iterable):
            items = list(obj)
            if not items:
                return {}
            acc = items[0]
            for t in items[1:]:
                acc = _merge(acc, t)
            return acc
        raise TypeError("Expected EnvTransition or iterable of them")

    tr = _ensure_transition(transitions_or_transition)
    obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
    act = tr.get(TransitionKey.ACTION, {}) or {}
    batch: dict[str, Any] = {}

    # Images passthrough
    for k in image_keys:
        if k in obs:
            batch[k] = obs[k]

    # Observation.state vector
    if obs_state_names:
        vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in obs_state_names]
        batch["observation.state"] = np.asarray(vals, dtype=np.float32)

    # Action vector
    if action_names:
        vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names]
        batch["action"] = np.asarray(vals, dtype=np.float32)

    # Next.* fields
    if tr.get(TransitionKey.REWARD) is not None:
        batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD])
    if tr.get(TransitionKey.DONE) is not None:
        batch["next.done"] = _from_tensor(tr[TransitionKey.DONE])
    if tr.get(TransitionKey.TRUNCATED) is not None:
        batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED])

    # Complementary data flags and task
    comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
    if comp:
        # pad flags
        for k, v in comp.items():
            if k.endswith("_is_pad"):
                batch[k] = v
        # task label
        if comp.get("task") is not None:
            batch["task"] = comp["task"]

    return batch
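The flattening that `to_dataset_frame` performs can be sketched in isolation. The snippet below mimics the state/action vector assembly with plain string keys standing in for `TransitionKey` members (an assumption for illustration; the real converter consumes `EnvTransition` dicts and also handles images, reward flags, and complementary data):

```python
import numpy as np


def flatten_to_frame(transition: dict, features: dict) -> dict:
    """Illustrative stand-in for to_dataset_frame's vector assembly."""
    state_names = features.get("observation.state", {}).get("names", [])
    action_names = features.get("action", {}).get("names", [])
    obs = transition.get("observation", {})
    act = transition.get("action", {})
    frame = {}
    if state_names:
        # Missing keys default to 0.0, matching the converter's obs.get(..., 0.0)
        frame["observation.state"] = np.asarray(
            [obs.get(f"observation.state.{n}", 0.0) for n in state_names], dtype=np.float32
        )
    if action_names:
        frame["action"] = np.asarray(
            [act.get(f"action.{n}", 0.0) for n in action_names], dtype=np.float32
        )
    return frame


features = {
    "observation.state": {"names": ["shoulder.pos", "elbow.pos"]},
    "action": {"names": ["shoulder.pos", "elbow.pos"]},
}
transition = {
    "observation": {
        "observation.state.shoulder.pos": 0.1,
        "observation.state.elbow.pos": 0.2,
    },
    "action": {"action.shoulder.pos": 0.3, "action.elbow.pos": 0.4},
}
frame = flatten_to_frame(transition, features)
```

The per-feature `names` ordering is what makes the flat vectors reproducible across frames.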
@@ -141,5 +141,5 @@ class DeviceProcessor:
         """Reset processor state (no-op for this processor)."""
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -257,7 +257,7 @@ class NormalizerProcessor:
     def reset(self):
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -435,7 +435,7 @@ class UnnormalizerProcessor:
     def reset(self):
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -106,9 +106,8 @@ class VanillaObservationProcessor(ObservationProcessor):
     def observation(self, observation):
         return self._process_observation(observation)
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         """Transforms feature keys to a standardized contract.
 
         This method handles several renaming patterns:
         - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
         - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').

@@ -23,7 +23,7 @@ from copy import deepcopy
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Any, Protocol, TypedDict
+from typing import Any, Protocol, TypedDict, runtime_checkable
 
 import torch
 from huggingface_hub import ModelHubMixin, hf_hub_download

@@ -132,6 +132,7 @@ class ProcessorStepRegistry:
         cls._registry.clear()
 
 
+@runtime_checkable
 class ProcessorStep(Protocol):
     """Structural typing interface for a single processor step.

@@ -145,7 +146,6 @@ class ProcessorStep(Protocol):
 
     **Required**:
     - ``__call__(transition: EnvTransition) -> EnvTransition``
-    - ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
 
     Optional helper protocol:
     * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable

@@ -158,6 +158,8 @@ class ProcessorStep(Protocol):
     * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict
       containing torch tensors only.
     * ``reset()`` – Clear internal buffers at episode boundaries.
+    * ``transform_features(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
+      If present, this method will be called to aggregate the dataset features of all steps.
 
     Example separation:
     - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}

@@ -174,7 +176,7 @@ class ProcessorStep(Protocol):
 
     def reset(self) -> None: ...
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
 
 
 def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition:  # noqa: D401

@@ -354,7 +356,10 @@ class RobotProcessor(ModelHubMixin):
             hook(idx, current_transition)
 
         # Convert back to original format if needed
-        return self.to_output(current_transition) if called_with_batch else current_transition
+        if called_with_batch or self.to_output is not _default_transition_to_batch:
+            return self.to_output(current_transition)
+        else:
+            return current_transition
 
     def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
         """Prepare and validate transition data for processing.

@@ -819,23 +824,15 @@ class RobotProcessor(ModelHubMixin):
                 f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
             )
 
-            fc = getattr(step, "feature_contract", None)
-            if not callable(fc):
-                raise TypeError(
-                    f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]"
-                )
-
-    def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         """
-        Apply ALL steps in order. Each step must implement
-        feature_contract(features) and return a dict (full or incremental schema).
+        Apply ALL steps in order. Only if a step has a features method will it be called.
+        We aggregate the dataset features of all steps.
         """
         features: dict[str, PolicyFeature] = deepcopy(initial_features)
 
         for _, step in enumerate(self.steps):
-            out = step.feature_contract(features)
-            if not isinstance(out, dict):
-                raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
+            out = step.transform_features(features)
             features = out
         return features

@@ -895,7 +892,7 @@ class ObservationProcessor:
     def reset(self) -> None:
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -955,7 +952,7 @@ class ActionProcessor:
     def reset(self) -> None:
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -1014,7 +1011,7 @@ class RewardProcessor:
     def reset(self) -> None:
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -1078,7 +1075,7 @@ class DoneProcessor:
     def reset(self) -> None:
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -1138,7 +1135,7 @@ class TruncatedProcessor:
     def reset(self) -> None:
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -1203,7 +1200,7 @@ class InfoProcessor:
     def reset(self) -> None:
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -1249,7 +1246,7 @@ class ComplementaryDataProcessor:
     def reset(self) -> None:
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -1271,5 +1268,5 @@ class IdentityProcessor:
     def reset(self) -> None:
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features

@@ -43,7 +43,7 @@ class RenameProcessor(ObservationProcessor):
     def get_config(self) -> dict[str, Any]:
         return {"rename_map": self.rename_map}
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         """Transforms:
         - Each key in the observation that appears in `rename_map` is renamed to its value.
         - Keys not in `rename_map` remain unchanged.

@@ -187,7 +187,7 @@ class TokenizerProcessor:
         """Reset processor state (no-op for this processor)."""
         pass
 
-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         """Add tokenized task features to the feature contract.
 
         Args:
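The aggregation that the renamed `RobotProcessor.transform_features` performs over its steps boils down to folding each step's `transform_features` over an initial feature dict. A self-contained sketch with illustrative step classes (not the library's classes; the real method also starts from a deep copy of the initial features):

```python
from copy import deepcopy


class RenameStep:
    """Illustrative step: renames feature keys according to a map."""

    def __init__(self, rename_map):
        self.rename_map = rename_map

    def transform_features(self, features):
        return {self.rename_map.get(k, k): v for k, v in features.items()}


class IdentityStep:
    """Illustrative step: leaves features unchanged."""

    def transform_features(self, features):
        return features


def aggregate_features(steps, initial_features):
    # Fold transform_features over the steps, in order.
    features = deepcopy(initial_features)
    for step in steps:
        features = step.transform_features(features)
    return features


steps = [RenameStep({"pixels": "observation.image"}), IdentityStep()]
out = aggregate_features(steps, {"pixels": "img_spec", "agent_pos": "state_spec"})
```

Because each step receives the previous step's output, the order of steps in the pipeline determines the final feature schema.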
+117 -28
@@ -72,12 +72,19 @@ from lerobot.configs import parser
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.datasets.image_writer import safe_stop_image_writer
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
+from lerobot.datasets.utils import hw_to_dataset_features
 from lerobot.datasets.video_utils import VideoEncodingManager
 from lerobot.policies.factory import make_policy, make_processor
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.processor import RobotProcessor
+from lerobot.processor.converters import (
+    to_dataset_frame,
+    to_output_robot_action,
+    to_transition_robot_observation,
+    to_transition_teleop_action,
+)
+from lerobot.processor.normalize_processor import rename_stats
+from lerobot.processor.pipeline import IdentityProcessor, TransitionKey
 from lerobot.robots import (  # noqa: F401
     Robot,
     RobotConfig,

@@ -191,6 +198,36 @@ class RecordConfig:
         return ["policy"]
 
 
+""" --------------- record_loop() data flow --------------------------
+
+                        [ Robot ]
+                            |
+                            v
+               [ robot.get_observation() ] ---> raw_obs
+                            |
+                            v
+            [ robot_observation_processor ] ---> obs_transition
+                            |
+                            v
+           .---------( ACTION LOGIC )----------------.
+           |                                         |
+           v                                         v
+  [ From Teleoperator ]                       [ From Policy ]
+           |                                         |
+  [ teleop.get_action ] -> raw_action         [ predict_action ]
+           |                                         |
+           v                                         |
+  [ teleop_action_processor ]                        |
+           |                                         |
+           '---> teleop_transition                   '---> policy_transition
+                     |                                         |
+                     '------------------.---------------------'
+                                        |
+                                        v
+                      [ robot_action_processor ] --> robot_action_to_send
+                                        |
+                                        v
+                         [ robot.send_action() ] -- (Robot Executes)
+                                        |
+                                        v
+                ( Transitions are merged & added to Dataset )
+                                        |
+                                        v
+                          ( Rerun Log / Loop Wait )
+"""
 
 
 @safe_stop_image_writer
 def record_loop(
     robot: Robot,
@@ -202,14 +239,27 @@ def record_loop(
     preprocessor: RobotProcessor | None = None,
     postprocessor: RobotProcessor | None = None,
     control_time_s: int | None = None,
+    teleop_action_processor: RobotProcessor | None = None,  # runs after teleop
+    robot_action_processor: RobotProcessor | None = None,  # runs before robot
+    robot_observation_processor: RobotProcessor | None = None,  # runs after robot
     single_task: str | None = None,
     display_data: bool = False,
 ):
+    teleop_action_processor = teleop_action_processor or RobotProcessor(
+        steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
+    )
+    robot_action_processor = robot_action_processor or RobotProcessor(
+        steps=[IdentityProcessor()], to_transition=lambda tr: tr, to_output=to_output_robot_action
+    )
+    robot_observation_processor = robot_observation_processor or RobotProcessor(
+        steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
+    )
+
     if dataset is not None and dataset.fps != fps:
         raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
 
     teleop_arm = teleop_keyboard = None
-    if isinstance(teleop, list):
+    if isinstance(teleop, list):  # For LeKiwi
         teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
         teleop_arm = next(
             (
@@ -226,11 +276,20 @@ def record_loop(
     )
 
     # Reset policy and processor if they are provided
-    if policy is not None or preprocessor is not None:
+    if policy is not None and preprocessor is not None and postprocessor is not None:
         policy.reset()
         preprocessor.reset()
         postprocessor.reset()
 
+    # Reset custom pipelines
+    teleop_action_processor.reset()
+    robot_action_processor.reset()
+    robot_observation_processor.reset()
+
+    policy_transition = None
+    teleop_transition = None
+    obs_transition = None
+
     timestamp = 0
     start_episode_t = time.perf_counter()
     while timestamp < control_time_s:
@@ -240,12 +299,19 @@ def record_loop(
             events["exit_early"] = False
             break
 
-        observation = robot.get_observation()
+        # Get robot observation
+        obs = robot.get_observation()
 
-        if policy is not None or dataset is not None:
-            observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
+        # Applies a pipeline to the raw robot observation, default is IdentityProcessor
+        obs_transition = robot_observation_processor(obs)
+
+        if dataset is not None:
+            observation_frame = to_dataset_frame(
+                obs_transition, dataset.features
+            )  # Convert the observation to the dataset format
 
         # Get action from either policy or teleop
-        if policy is not None or preprocessor is not None:
+        if policy is not None and preprocessor is not None and postprocessor is not None:
             action_values = predict_action(
                 observation=observation_frame,
                 policy=policy,
@@ -256,37 +322,64 @@ def record_loop(
                 task=single_task,
                 robot_type=robot.robot_type,
             )
-            action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
-        elif policy is None and isinstance(teleop, Teleoperator):
-            action = teleop.get_action()
-        elif policy is None and isinstance(teleop, list):
-            # TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
+
+            action_names = dataset.features["action"]["names"]
+            policy_action = {f"action.{name}": float(action_values[i]) for i, name in enumerate(action_names)}
+            policy_transition = {
+                TransitionKey.ACTION: policy_action,
+                TransitionKey.COMPLEMENTARY_DATA: {},
+            }
+
+        elif isinstance(teleop, Teleoperator):
+            act = teleop.get_action()
+
+            # Applies a pipeline to the raw teleop action, default is IdentityProcessor
+            teleop_transition = teleop_action_processor(act)
+
+        elif isinstance(teleop, list):
             arm_action = teleop_arm.get_action()
             arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
 
             keyboard_action = teleop_keyboard.get_action()
             base_action = robot._from_keyboard_to_base_action(keyboard_action)
 
-            action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
+            act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
+            teleop_transition = teleop_action_processor(act)
         else:
             logging.info(
-                "No policy or teleoperator provided, skipping action generation."
-                "This is likely to happen when resetting the environment without a teleop device."
-                "The robot won't be at its rest position at the start of the next episode."
+                "No policy or teleoperator provided, skipping action generation. "
+                "This is likely to happen during environment reset."
             )
-            continue
+            continue  # Still continue to next loop to respect timing
 
+        # Applies a pipeline to the action, default is IdentityProcessor
+        # IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
+        if policy_transition is not None:
+            robot_action_to_send = robot_action_processor(policy_transition)
+        else:
+            robot_action_to_send = robot_action_processor(teleop_transition)
 
         # Send action to robot
-        # Action can eventually be clipped using `max_relative_target`,
-        # so action actually sent is saved in the dataset. action = postprocessor.process(action)
-        sent_action = robot.send_action(action)
+        # TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
+        _ = robot.send_action(robot_action_to_send)
 
         # Write to dataset
         if dataset is not None:
-            action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
-            frame = {**observation_frame, **action_frame}
+            # If to_dataset_frame is provided, use it to merge the transitions.
+            merged = []
+            if obs_transition is not None:  # The observation from the robot
+                merged.append(obs_transition)
+            if teleop_transition is not None:  # The action from teleop
+                merged.append(teleop_transition)
+            if policy_transition is not None:  # The action from policy
+                merged.append(policy_transition)
+            frame = to_dataset_frame(
+                merged if len(merged) > 1 else merged[0], dataset.features
+            )  # Convert the observation to the dataset format
             dataset.add_frame(frame, task=single_task)
 
         if display_data:
-            log_rerun_data(observation, action)
+            log_rerun_data([obs_transition, teleop_transition or policy_transition])
 
         dt_s = time.perf_counter() - start_loop_t
         busy_wait(1 / fps - dt_s)
@@ -417,9 +510,5 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
     return dataset
 
 
-def main():
-    record()
-
-
 if __name__ == "__main__":
-    main()
+    record()
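The precedence `record_loop` applies when choosing the action source and assembling the dataset frame can be sketched without any robot I/O (the function name and dict shapes below are illustrative, not lerobot APIs): a policy transition wins over a teleop transition, and every available transition is merged into the frame.

```python
def select_and_merge(policy_transition, teleop_transition, obs_transition):
    """Illustrative stand-in for record_loop's action-source and merge logic."""
    # Policy takes priority over teleop, mirroring the if/elif ordering above.
    action_source = policy_transition if policy_transition is not None else teleop_transition
    if action_source is None:
        raise ValueError("No policy or teleoperator action available")
    # All non-None transitions are collected for the dataset frame.
    merged = [t for t in (obs_transition, teleop_transition, policy_transition) if t is not None]
    return action_source, merged


src, merged = select_and_merge({"action": 1}, {"action": 2}, {"obs": 3})
```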
@@ -14,6 +14,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
+from .config_so100_follower import SO100FollowerConfig
 from .so100_follower import SO100Follower
 from .so100_follower_end_effector import SO100FollowerEndEffector

@@ -39,35 +39,3 @@ class SO100FollowerConfig(RobotConfig):
 
     # Set to `True` for backward compatibility with previous policies/dataset
     use_degrees: bool = False
-
-
-@RobotConfig.register_subclass("so100_follower_end_effector")
-@dataclass
-class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
-    """Configuration for the SO100FollowerEndEffector robot."""
-
-    # Path to URDF file for kinematics
-    # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
-    # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
-    urdf_path: str | None = None
-
-    # End-effector frame name in the URDF
-    target_frame_name: str = "gripper_frame_link"
-
-    # Default bounds for the end-effector position (in meters)
-    end_effector_bounds: dict[str, list[float]] = field(
-        default_factory=lambda: {
-            "min": [-1.0, -1.0, -1.0],  # min x, y, z
-            "max": [1.0, 1.0, 1.0],  # max x, y, z
-        }
-    )
-
-    max_gripper_pos: float = 50
-
-    end_effector_step_sizes: dict[str, float] = field(
-        default_factory=lambda: {
-            "x": 0.02,
-            "y": 0.02,
-            "z": 0.02,
-        }
-    )
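The `end_effector_bounds` dict above is consumed as plain min/max arrays; clipping a commanded position against it is a single `np.clip`, as the safety step below does. A minimal sketch with the default bounds:

```python
import numpy as np

# Default bounds from SO100FollowerEndEffectorConfig (in meters)
bounds = {"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}

# A commanded position outside the workspace gets clamped per axis.
target = np.array([1.5, 0.0, -2.0])
clipped = np.clip(target, bounds["min"], bounds["max"])
```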
@@ -0,0 +1,447 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_reference_and_delta")
|
||||
@dataclass
|
||||
class EEReferenceAndDelta:
|
||||
"""
|
||||
Compute the desired end-effector pose from the target pose and the current pose.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
end_effector_step_sizes: dict
|
||||
motor_names: list[str]
|
||||
|
||||
reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
||||
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
    def __call__(self, transition: EnvTransition) -> EnvTransition:
        act = transition.get(TransitionKey.ACTION) or {}
        comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}

        # Get joint positions from complementary data
        raw = comp.get("raw_joint_positions", None)
        if raw is None:
            raise ValueError(
                "raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
            )

        q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)

        # Current pose from FK on measured joints
        t_curr = self.kinematics.forward_kinematics(q)

        enabled = bool(act.pop("action.enabled", 0))
        tx = float(act.pop("action.target_x", 0.0))
        ty = float(act.pop("action.target_y", 0.0))
        tz = float(act.pop("action.target_z", 0.0))
        wx = float(act.pop("action.target_wx", 0.0))
        wy = float(act.pop("action.target_wy", 0.0))
        wz = float(act.pop("action.target_wz", 0.0))

        desired = None

        if enabled:
            # Latch a reference at the rising edge; also be defensive if None
            if not self._prev_enabled or self.reference_ee_pose is None:
                self.reference_ee_pose = t_curr.copy()

            ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr

            delta_p = np.array(
                [
                    tx * self.end_effector_step_sizes["x"],
                    ty * self.end_effector_step_sizes["y"],
                    tz * self.end_effector_step_sizes["z"],
                ],
                dtype=float,
            )
            r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()

            desired = np.eye(4, dtype=float)
            desired[:3, :3] = ref[:3, :3] @ r_abs
            desired[:3, 3] = ref[:3, 3] + delta_p

            self._command_when_disabled = desired.copy()
        else:
            # While disabled, keep sending the same command to avoid drift.
            if self._command_when_disabled is None:
                # If we have never had an enabled command yet, freeze the current FK pose once.
                self._command_when_disabled = t_curr.copy()
            desired = self._command_when_disabled.copy()

        # Write action fields
        pos = desired[:3, 3]
        tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
        act.update(
            {
                "action.ee.x": float(pos[0]),
                "action.ee.y": float(pos[1]),
                "action.ee.z": float(pos[2]),
                "action.ee.wx": float(tw[0]),
                "action.ee.wy": float(tw[1]),
                "action.ee.wz": float(tw[2]),
            }
        )

        self._prev_enabled = enabled
        transition[TransitionKey.ACTION] = act
        return transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


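The rising-edge latching and per-axis delta scaling above can be sketched in isolation. This is a minimal numpy-only sketch with illustrative helper names (not part of the pipeline); only the translation part of the pose is composed here:

```python
import numpy as np

def latch_reference(prev_enabled: bool, enabled: bool, reference, current):
    """Latch the reference pose on the rising edge of `enabled` (or if unset)."""
    if enabled and (not prev_enabled or reference is None):
        reference = current.copy()
    return reference

def compose_delta(ref: np.ndarray, delta_xyz, step_sizes) -> np.ndarray:
    """Apply per-axis scaled position deltas to a 4x4 reference pose."""
    desired = ref.copy()
    desired[:3, 3] = ref[:3, 3] + np.asarray(delta_xyz) * np.asarray(step_sizes)
    return desired

# Rising edge: reference is captured from the current pose
ref = latch_reference(False, True, None, np.eye(4))
target = compose_delta(ref, [1.0, 0.0, -1.0], [0.01, 0.01, 0.01])
```

While `enabled` stays high, `latch_reference` keeps returning the same latched pose, so deltas are always relative to the pose at the moment teleoperation was engaged.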
@ProcessorStepRegistry.register("ee_bounds_and_safety")
@dataclass
class EEBoundsAndSafety(ActionProcessor):
    """
    Clip the end-effector pose to the bounds and check for jumps.

    Input ACTION keys:
    {
        "action.ee.{x,y,z,wx,wy,wz}": float,
    }

    Output ACTION keys:
    {
        "action.ee.{x,y,z,wx,wy,wz}": float,
    }
    """

    end_effector_bounds: dict
    max_ee_step_m: float = 0.05
    max_ee_twist_step_rad: float = 0.20
    _last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    _last_twist: np.ndarray | None = field(default=None, init=False, repr=False)

    def action(self, act: dict | None) -> dict:
        x = act.pop("action.ee.x", None)
        y = act.pop("action.ee.y", None)
        z = act.pop("action.ee.z", None)
        wx = act.pop("action.ee.wx", None)
        wy = act.pop("action.ee.wy", None)
        wz = act.pop("action.ee.wz", None)

        if None in (x, y, z, wx, wy, wz):
            return act

        pos = np.array([x, y, z], dtype=float)
        twist = np.array([wx, wy, wz], dtype=float)

        # Clip position
        pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"])

        # Check for jumps in position
        if self._last_pos is not None:
            dpos = pos - self._last_pos
            n = float(np.linalg.norm(dpos))
            if n > self.max_ee_step_m:
                raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")

        self._last_pos = pos
        self._last_twist = twist

        act.update(
            {
                "action.ee.x": float(pos[0]),
                "action.ee.y": float(pos[1]),
                "action.ee.z": float(pos[2]),
                "action.ee.wx": float(twist[0]),
                "action.ee.wy": float(twist[1]),
                "action.ee.wz": float(twist[2]),
            }
        )
        return act

    def reset(self):
        self._last_pos = None
        self._last_twist = None

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # Because this is the last step, we declare the dataset features from this step
        # that should be stored in the dataset
        features["action.ee.x"] = float
        features["action.ee.y"] = float
        features["action.ee.z"] = float
        features["action.ee.wx"] = float
        features["action.ee.wy"] = float
        features["action.ee.wz"] = float
        return features


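The clip-then-guard logic of `EEBoundsAndSafety` reduces to a few lines. A minimal sketch with illustrative helper names, assuming per-axis bounds arrays:

```python
import numpy as np

def clip_and_check(pos, last_pos, bounds_min, bounds_max, max_step):
    """Clip a target position to bounds, then reject steps larger than max_step."""
    pos = np.clip(np.asarray(pos, dtype=float), bounds_min, bounds_max)
    if last_pos is not None:
        step = float(np.linalg.norm(pos - last_pos))
        if step > max_step:
            raise ValueError(f"EE jump {step:.3f}m > {max_step}m")
    return pos

# First command: no previous position, only the bounds apply
safe = clip_and_check([0.5, 0.0, 0.2], None, [-0.3, -0.3, 0.0], [0.3, 0.3, 0.4], 0.05)
```

Note that clipping happens before the jump check, so a wildly out-of-bounds command is first pulled back to the workspace and only then compared against the previous position.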
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
@dataclass
class InverseKinematicsEEToJoints:
    """
    Compute the desired joint positions from the desired end-effector pose.

    Input ACTION keys:
    {
        "action.ee.{x,y,z,wx,wy,wz}": float,
    }

    Required COMPLEMENTARY_DATA keys:
    {
        "raw_joint_positions": dict,
    }

    Output ACTION keys:
    {
        "action.joint_name_1.pos": float,
        "action.joint_name_2.pos": float,
        ...
        "action.joint_name_n.pos": float,
    }
    """

    kinematics: RobotKinematics
    motor_names: list[str]
    q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
    initial_guess_current_joints: bool = True

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        act = transition.get(TransitionKey.ACTION) or {}
        comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}

        x = act.get("action.ee.x", None)
        y = act.get("action.ee.y", None)
        z = act.get("action.ee.z", None)
        wx = act.get("action.ee.wx", None)
        wy = act.get("action.ee.wy", None)
        wz = act.get("action.ee.wz", None)

        if None in (x, y, z, wx, wy, wz):
            # Some EE keys are missing; nothing to do. The keys were read with
            # .get(), so the action is left untouched.
            transition[TransitionKey.ACTION] = act
            return transition

        # Get joint positions from complementary data
        raw = comp.get("raw_joint_positions", None)
        if raw is None:
            raise ValueError(
                "raw_joint_positions is not in complementary data and is required for InverseKinematicsEEToJoints"
            )

        if self.initial_guess_current_joints:  # Use current joints as initial guess
            self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
        else:  # Use the previous IK solution as initial guess
            if self.q_curr is None:
                self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)

        # Build the desired 4x4 transform from pos + rotvec (twist)
        t_des = np.eye(4, dtype=float)
        t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
        t_des[:3, 3] = [x, y, z]

        # Compute inverse kinematics
        q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
        self.q_curr = q_target

        new_act = dict(act)
        for i, name in enumerate(self.motor_names):
            if name == "gripper":
                # The gripper is not part of the IK chain; pass through its measured position
                new_act["observation.state.gripper.pos"] = float(raw["gripper"])
            else:
                new_act[f"action.{name}.pos"] = float(q_target[i])
        transition[TransitionKey.ACTION] = new_act
        return transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # We declare the dataset features from this step that should be stored in the dataset
        features["action.ee.x"] = float
        features["action.ee.y"] = float
        features["action.ee.z"] = float
        features["action.ee.wx"] = float
        features["action.ee.wy"] = float
        features["action.ee.wz"] = float

        features["observation.state.gripper.pos"] = float
        features["action.gripper.pos"] = float
        return features

    def reset(self):
        self.q_curr = None


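`Rotation.from_rotvec(...).as_matrix()` is what fills the rotation block of `t_des`; the same conversion can be written with Rodrigues' formula in plain numpy. A sketch with illustrative names, showing how the 4x4 IK target is assembled from a position and a rotation vector:

```python
import numpy as np

def rotvec_to_matrix(rotvec):
    """Rodrigues' formula: rotation vector -> 3x3 rotation matrix."""
    rotvec = np.asarray(rotvec, dtype=float)
    theta = np.linalg.norm(rotvec)
    if theta < 1e-12:
        return np.eye(3)
    k = rotvec / theta  # unit rotation axis
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

def pose_from_pos_rotvec(pos, rotvec):
    """Assemble the homogeneous 4x4 target transform the IK step solves for."""
    t = np.eye(4)
    t[:3, :3] = rotvec_to_matrix(rotvec)
    t[:3, 3] = pos
    return t

# 90-degree rotation about z, at position (0.2, 0, 0.1)
t_des = pose_from_pos_rotvec([0.2, 0.0, 0.1], [0.0, 0.0, np.pi / 2])
```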
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
@dataclass
class GripperVelocityToJoint:
    """
    Integrate a gripper velocity command into a target gripper joint position.

    Input ACTION keys:
    {
        "action.gripper": float,
    }

    Output ACTION keys:
    {
        "action.gripper.pos": float,
    }
    """

    motor_names: list[str]
    speed_factor: float = 20.0
    clip_min: float = 0.0
    clip_max: float = 100.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        obs = transition.get(TransitionKey.OBSERVATION) or {}
        act = transition.get(TransitionKey.ACTION) or {}
        comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}

        if "action.gripper" not in act:
            return transition

        if "gripper" not in self.motor_names:
            new_act = dict(act)
            new_act.pop("action.gripper", None)
            transition[TransitionKey.ACTION] = new_act
            return transition

        # Get the current gripper position from complementary data
        raw = comp.get("raw_joint_positions") or {}
        if "gripper" not in raw:
            raise ValueError(
                "raw_joint_positions must contain 'gripper' for GripperVelocityToJoint"
            )
        curr_pos = float(raw["gripper"])

        # Integrate the velocity command into a clipped target position
        u = float(act.get("action.gripper", 0.0))
        delta = u * float(self.speed_factor)
        gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))

        new_act = dict(act)
        new_act["action.gripper.pos"] = gripper_pos
        new_act.pop("action.gripper", None)
        transition[TransitionKey.ACTION] = new_act

        obs.update({"observation.state.gripper.pos": curr_pos})
        transition[TransitionKey.OBSERVATION] = obs
        return transition

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # We declare the dataset features from this step that should be stored in the dataset
        features["observation.state.gripper.pos"] = float
        features["action.gripper.pos"] = float
        return features


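The integration step of `GripperVelocityToJoint` reduces to one clipped update; a minimal sketch with the same defaults (helper name is illustrative):

```python
import numpy as np

def integrate_gripper(curr_pos, u, speed_factor=20.0, clip_min=0.0, clip_max=100.0):
    """One step of velocity-command integration into a clipped target position."""
    return float(np.clip(curr_pos + u * speed_factor, clip_min, clip_max))

opened = integrate_gripper(50.0, 1.0)   # full-open command: 50 + 20 -> 70
closed = integrate_gripper(5.0, -1.0)   # 5 - 20 -> clipped at the lower bound, 0
```

Because the command `u` is a velocity in [-1, 1], the same held input keeps moving the gripper a fixed amount per step until it saturates at a bound.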
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
@dataclass
class ForwardKinematicsJointsToEE(ObservationProcessor):
    """
    Compute the end-effector pose from the joint positions.

    Input OBSERVATION keys:
    {
        "observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
    }

    Output OBSERVATION keys:
    {
        "observation.state.ee.{x,y,z,wx,wy,wz}": float,
    }
    """

    kinematics: RobotKinematics
    motor_names: list[str]

    def observation(self, obs: dict | None) -> dict:
        if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names):
            return obs

        q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float)
        t = self.kinematics.forward_kinematics(q)
        pos = t[:3, 3]
        tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()

        obs.update(
            {
                "observation.state.ee.x": float(pos[0]),
                "observation.state.ee.y": float(pos[1]),
                "observation.state.ee.z": float(pos[2]),
                "observation.state.ee.wx": float(tw[0]),
                "observation.state.ee.wy": float(tw[1]),
                "observation.state.ee.wz": float(tw[2]),
            }
        )
        return obs

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # We declare the dataset features from this step that should be stored in the dataset
        for k in ["x", "y", "z", "wx", "wy", "wz"]:
            features[f"observation.state.ee.{k}"] = float
        return features


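The all-keys-present gating used by `ForwardKinematicsJointsToEE.observation` can be sketched on plain dicts (function name is illustrative): only when every motor key is present is the ordered joint vector built, otherwise the step is a no-op.

```python
def extract_joint_vector(obs: dict, motor_names: list[str]):
    """Gather ordered joint positions; return None if any motor key is missing."""
    keys = [f"observation.state.{n}.pos" for n in motor_names]
    if not all(k in obs for k in keys):
        return None
    return [float(obs[k]) for k in keys]

q = extract_joint_vector(
    {"observation.state.shoulder_pan.pos": 1.5, "observation.state.elbow_flex.pos": -0.5},
    ["shoulder_pan", "elbow_flex"],
)
```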
@ProcessorStepRegistry.register("add_robot_observation")
@dataclass
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
    """
    Read the robot's current observation and insert it into the transition as complementary data.

    - Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
      { "<motor_name>": <float position>, ... }
    """

    robot: Robot

    def complementary_data(self, comp: dict | None) -> dict:
        comp = {} if comp is None else dict(comp)
        obs = self.robot.get_observation()

        comp["raw_joint_positions"] = {
            k.removesuffix(".pos"): float(v)
            for k, v in obs.items()
            if isinstance(k, str) and k.endswith(".pos")
        }
        return comp

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features

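The `.pos`-suffix stripping above is self-contained and easy to check in isolation (function name is illustrative; non-`.pos` entries such as camera frames are filtered out):

```python
def raw_joint_positions(obs: dict) -> dict:
    """Collect motor positions keyed by bare motor name, as the processor stores them."""
    return {
        k.removesuffix(".pos"): float(v)
        for k, v in obs.items()
        if isinstance(k, str) and k.endswith(".pos")
    }

raw = raw_joint_positions({"gripper.pos": 42.0, "wrist_flex.pos": -3.0, "cam_top": object()})
```

`str.removesuffix` requires Python 3.9+; on older interpreters the equivalent is `k[: -len(".pos")]` guarded by `k.endswith(".pos")`.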
@@ -1,200 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from typing import Any

import numpy as np

from lerobot.cameras import make_cameras_from_configs
from lerobot.errors import DeviceNotConnectedError
from lerobot.model.kinematics import RobotKinematics
from lerobot.motors import Motor, MotorNormMode
from lerobot.motors.feetech import FeetechMotorsBus

from . import SO100Follower
from .config_so100_follower import SO100FollowerEndEffectorConfig

logger = logging.getLogger(__name__)


class SO100FollowerEndEffector(SO100Follower):
    """
    SO100Follower robot with end-effector space control.

    This robot inherits from SO100Follower but transforms actions from
    end-effector space to joint space before sending them to the motors.
    """

    config_class = SO100FollowerEndEffectorConfig
    name = "so100_follower_end_effector"

    def __init__(self, config: SO100FollowerEndEffectorConfig):
        super().__init__(config)
        self.bus = FeetechMotorsBus(
            port=self.config.port,
            motors={
                "shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES),
                "shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES),
                "elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES),
                "wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES),
                "wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES),
                "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
            },
            calibration=self.calibration,
        )

        self.cameras = make_cameras_from_configs(config.cameras)

        self.config = config

        # Initialize the kinematics module for the so100 robot
        if self.config.urdf_path is None:
            raise ValueError(
                "urdf_path must be provided in the configuration for end-effector control. "
                "Please set urdf_path in your SO100FollowerEndEffectorConfig."
            )

        self.kinematics = RobotKinematics(
            urdf_path=self.config.urdf_path,
            target_frame_name=self.config.target_frame_name,
        )

        # Store the bounds for end-effector position
        self.end_effector_bounds = self.config.end_effector_bounds

        self.current_ee_pos = None
        self.current_joint_pos = None

    @property
    def action_features(self) -> dict[str, Any]:
        """
        Define action features for end-effector control.
        Returns a dictionary with dtype, shape, and names.
        """
        return {
            "dtype": "float32",
            "shape": (4,),
            "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
        }

    def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
        """
        Transform an action from end-effector space to joint space and send it to the motors.

        Args:
            action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control,
                or a numpy array with [delta_x, delta_y, delta_z]

        Returns:
            The joint-space action that was sent to the motors
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        # Convert the action to a numpy array if it is not one already
        if isinstance(action, dict):
            if all(k in action for k in ["delta_x", "delta_y", "delta_z"]):
                delta_ee = np.array(
                    [
                        action["delta_x"] * self.config.end_effector_step_sizes["x"],
                        action["delta_y"] * self.config.end_effector_step_sizes["y"],
                        action["delta_z"] * self.config.end_effector_step_sizes["z"],
                    ],
                    dtype=np.float32,
                )
                if "gripper" not in action:
                    action["gripper"] = [1.0]
                action = np.append(delta_ee, action["gripper"])
            else:
                logger.warning(
                    f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
                )
                action = np.zeros(4, dtype=np.float32)

        if self.current_joint_pos is None:
            # Read current joint positions
            current_joint_pos = self.bus.sync_read("Present_Position")
            self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors])

        # Calculate the current end-effector position using forward kinematics
        if self.current_ee_pos is None:
            self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)

        # Set the desired end-effector position by adding the delta
        desired_ee_pos = np.eye(4)
        desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3]  # Keep orientation

        # Add the delta to the position and clip to bounds
        desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
        if self.end_effector_bounds is not None:
            desired_ee_pos[:3, 3] = np.clip(
                desired_ee_pos[:3, 3],
                self.end_effector_bounds["min"],
                self.end_effector_bounds["max"],
            )

        # Compute inverse kinematics to get joint positions
        target_joint_values_in_degrees = self.kinematics.inverse_kinematics(
            self.current_joint_pos, desired_ee_pos
        )

        # Create the joint-space action dictionary
        joint_action = {
            f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
        }

        # Handle the gripper separately if included in the action.
        # The gripper delta action is in the range [0, 2]; we shift it to [-1, 1]
        # so that it can be scaled to [-max_gripper_pos, max_gripper_pos].
        joint_action["gripper.pos"] = np.clip(
            self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
            5,
            self.config.max_gripper_pos,
        )

        self.current_ee_pos = desired_ee_pos.copy()
        self.current_joint_pos = target_joint_values_in_degrees.copy()
        self.current_joint_pos[-1] = joint_action["gripper.pos"]

        # Send the joint-space action to the parent class
        return super().send_action(joint_action)

    def get_observation(self) -> dict[str, Any]:
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        # Read arm position
        start = time.perf_counter()
        obs_dict = self.bus.sync_read("Present_Position")
        obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
        dt_ms = (time.perf_counter() - start) * 1e3
        logger.debug(f"{self} read state: {dt_ms:.1f}ms")

        # Capture images from cameras
        for cam_key, cam in self.cameras.items():
            start = time.perf_counter()
            obs_dict[cam_key] = cam.async_read()
            dt_ms = (time.perf_counter() - start) * 1e3
            logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")

        return obs_dict

    def reset(self):
        self.current_ee_pos = None
        self.current_joint_pos = None

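The gripper mapping in `send_action` above (shift the [0, 2] command to [-1, 1], scale, then clip) can be sketched standalone; `max_gripper_pos=50.0` and the lower clip of 5 are illustrative values taken from the pattern above, not verified config defaults:

```python
import numpy as np

def gripper_target(current_pos, gripper_cmd, max_gripper_pos=50.0):
    """Map a [0, 2] gripper command to a clipped absolute target position."""
    # shift [0, 2] -> [-1, 1], then scale to +/- max_gripper_pos
    return float(np.clip(current_pos + (gripper_cmd - 1.0) * max_gripper_pos, 5.0, max_gripper_pos))

hold = gripper_target(30.0, 1.0)   # command 1.0 holds the current position
close = gripper_target(30.0, 0.0)  # command 0.0 drives toward closed, clipped at 5
```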
@@ -69,6 +69,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
        raise ValueError(config.type)


# TODO(pepijn): Move this to a pipeline step so we don't have to do it in the robot code,
# and so the action sent to the robot stays clean for use in a dataset.
def ensure_safe_goal_position(
    goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float]
) -> dict[str, float]:
@@ -109,7 +109,7 @@ def teleop_loop(
        action = teleop.get_action()
        if display_data:
            observation = robot.get_observation()
            log_rerun_data(observation, action)
            log_rerun_data(observation=observation, action=action)

        robot.send_action(action)
        dt_s = time.perf_counter() - loop_start
@@ -0,0 +1,18 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .config_phone import PhoneConfig
from .phone import Phone

@@ -0,0 +1,36 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from enum import Enum

import numpy as np

from ..config import TeleoperatorConfig


class PhoneOS(Enum):
    ANDROID = "android"
    IOS = "ios"


@TeleoperatorConfig.register_subclass("phone")
@dataclass
class PhoneConfig(TeleoperatorConfig):
    phone_os: PhoneOS = PhoneOS.IOS
    # The iPhone 14 Pro camera is 2 cm off center and 4 cm above center
    camera_offset: np.ndarray = field(default_factory=lambda: np.array([0.0, -0.02, 0.04]))

@@ -0,0 +1,246 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Docs:
# hebi: https://docs.hebi.us/tools.html#mobile-io
# teleop: https://github.com/SpesRobotics/teleop

import logging
import threading
import time

import hebi
import numpy as np
from scipy.spatial.transform import Rotation
from teleop import Teleop

from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.teleoperator import Teleoperator

logger = logging.getLogger(__name__)


class Phone(Teleoperator):
    """
    Phone-based teleoperator using ARKit (iOS, via the HEBI Mobile I/O app) or the teleop Python package (Android, via the WebXR API).
    For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.

    Press and hold **B1** to enable teleoperation. The first B1 press captures a
    reference pose and rotation; on each subsequent rising edge (disabled, then
    pressed again) the position reference is re-captured.
    """

    config_class = PhoneConfig
    name = "phone"

    def __init__(self, config: PhoneConfig):
        super().__init__(config)
        self.config = config
        self._group = None
        self._teleop = None
        self._teleop_thread = None
        self._latest_pose = None
        self._latest_message = None
        self._enabled: bool = False
        self._calib_pos: np.ndarray | None = None
        self._calib_rot_inv: Rotation | None = None

    @property
    def is_connected(self) -> bool:
        return (self.config.phone_os == PhoneOS.IOS and self._group is not None) or (
            self.config.phone_os == PhoneOS.ANDROID and self._teleop is not None
        )

    def connect(self) -> None:
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} already connected")

        if self.config.phone_os == PhoneOS.IOS:
            logger.info("Connecting to iPhone, make sure to open the HEBI Mobile I/O app.")
            lookup = hebi.Lookup()
            time.sleep(2.0)
            group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
            if group is None:
                raise RuntimeError("Mobile I/O not found; check the name/family settings in the app.")
            self._group = group
            logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
        elif self.config.phone_os == PhoneOS.ANDROID:
            logger.info("Starting teleop stream for Android...")
            self._teleop = Teleop()
            self._teleop.subscribe(self._android_callback)
            self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
            self._teleop_thread.start()
            logger.info(f"{self} connected, teleop stream started.")
        else:
            raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")

        self.calibrate()

    def calibrate(self) -> None:
        print(
            "Hold the phone so that its top edge points forward, in the same direction as the robot (+x), "
            "and its screen points up (robot +z)."
        )
        if self.config.phone_os == PhoneOS.IOS:
            print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
        else:
            print("Touch and move on the WebXR page to capture this pose...\n")

        pos, rot = self._wait_for_capture_trigger()
        self._calib_pos = pos.copy()
        self._calib_rot_inv = rot.inv()
        self._enabled = False
        print("Calibration done\n")

    def _reapply_position_calibration(self, pos: np.ndarray) -> None:
        self._calib_pos = pos.copy()

    @property
    def is_calibrated(self) -> bool:
        return (self._calib_pos is not None) and (self._calib_rot_inv is not None)

    @property
    def action_features(self) -> dict[str, type]:
        return {
            "phone.pos": np.ndarray,  # shape (3,)
            "phone.rot": Rotation,  # scipy.spatial.transform.Rotation
            "phone.raw_inputs": dict,  # analogs/buttons or WebXR metadata
            "phone.enabled": bool,
        }

    def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
        """Wait for the calibration trigger: B1 on iOS, a 'move' event on Android."""
        while True:
            ok, pos, rot, pose = self._read_current_pose()
            if not ok:
                time.sleep(0.01)
                continue

            if self.config.phone_os == PhoneOS.IOS:
                io = getattr(pose, "io", None)
                b = getattr(io, "b", None) if io is not None else None
                b1 = False
                if b is not None:
                    b1 = bool(b.get_int(1))
                if b1:
                    return pos, rot
            else:
                msg = self._latest_message or {}
                if bool(msg.get("move", False)):
                    return pos, rot

            time.sleep(0.01)

    def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
        if self.config.phone_os == PhoneOS.IOS:
            fbk = self._group.get_next_feedback()
            pose = fbk[0]
            ar_pos = getattr(pose, "ar_position", None)
            ar_quat = getattr(pose, "ar_orientation", None)
            if ar_pos is None or ar_quat is None:
                return False, None, None, None
            quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]]))  # wxyz to xyzw
            rot = Rotation.from_quat(quat_xyzw)
            pos = ar_pos - rot.apply(self.config.camera_offset)
            return True, pos, rot, pose
        else:
            p = self._latest_pose
            if p is None:
                return False, None, None, None
            rot = Rotation.from_matrix(p[:3, :3])
            pos = p[:3, 3] - rot.apply(self.config.camera_offset)
            pose = self._latest_pose
            return True, pos, rot, pose

    @property
    def feedback_features(self) -> dict[str, type]:
        # No haptic or other feedback implemented yet
        return {}

    def configure(self) -> None:
        # No additional configuration required for phone teleop
        pass

    def _android_callback(self, pose: np.ndarray, message: dict) -> None:
        self._latest_pose = pose
        self._latest_message = message
        time.sleep(0.001)  # 1 ms delay to avoid a race condition

    def get_action(self) -> dict:
        ok, raw_pos, raw_rot, pose = self._read_current_pose()
        if not ok or not self.is_calibrated:
            return {}

        # Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
        raw_inputs: dict[str, float | int | bool] = {}
        if self.config.phone_os == PhoneOS.IOS:
            io = getattr(pose, "io", None)
            if io is not None:
                bank_a, bank_b = io.a, io.b
                if bank_a:
                    for ch in range(1, 9):
                        if bank_a.has_float(ch):
                            raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch))
                if bank_b:
                    for ch in range(1, 9):
                        if bank_b.has_int(ch):
                            raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch))
                        elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
                            raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch))
        else:
            msg = self._latest_message or {}
            raw_inputs["move"] = bool(msg.get("move", False))
            raw_inputs["scale"] = float(msg.get("scale", 1.0))
            raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False))
            raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False))

        if self.config.phone_os == PhoneOS.IOS:
            enable = bool(raw_inputs.get("b1", 0))
        else:
            enable = bool(raw_inputs.get("move", False))

        # On a rising edge, re-capture the position calibration from the current raw pose
        if enable and not self._enabled:
            self._reapply_position_calibration(raw_pos)

        # Apply calibration
        pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
        rot_cal = self._calib_rot_inv * raw_rot

        self._enabled = enable

        return {
            "phone.pos": pos_cal,
            "phone.rot": rot_cal,
            "phone.raw_inputs": raw_inputs,
            "phone.enabled": self._enabled,
        }

    def send_feedback(self, feedback: dict[str, float]) -> None:
        # We could add haptic feedback (vibrations) here, but it's not implemented yet
        raise NotImplementedError

    def disconnect(self) -> None:
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        if self.config.phone_os == PhoneOS.IOS:
            self._group = None
        else:
            self._teleop = None
            if self._teleop_thread and self._teleop_thread.is_alive():
                self._teleop_thread.join(timeout=1.0)
            self._teleop_thread = None
            self._latest_pose = None

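The wxyz-to-xyzw reordering in `_read_current_pose` is a one-liner worth isolating, since HEBI reports quaternions scalar-first while scipy's `Rotation.from_quat` expects scalar-last (function name is illustrative):

```python
import numpy as np

def wxyz_to_xyzw(q):
    """Reorder a scalar-first (w, x, y, z) quaternion to scalar-last (x, y, z, w)."""
    q = np.asarray(q, dtype=float)
    return np.concatenate((q[1:], q[:1]))

identity_xyzw = wxyz_to_xyzw([1.0, 0.0, 0.0, 0.0])  # identity rotation
```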
@@ -0,0 +1,87 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
from lerobot.teleoperators.phone.config_phone import PhoneOS


@ProcessorStepRegistry.register("map_phone_action_to_robot_action")
@dataclass
class MapPhoneActionToRobotAction(ActionProcessor):
    """
    Map a calibrated phone pose (actions) to the inputs for robot actions.

    Expected input ACTION keys:
    {
        "action.phone.enabled": bool,
        "action.phone.pos": np.ndarray,
        "action.phone.rot": Rotation,
        "action.phone.raw_inputs": dict,
    }

    Output ACTION keys:
    {
        "action.enabled": bool,
        "action.target_{x,y,z,wx,wy,wz}": float,
        "action.gripper": float,
    }
    """

    platform: PhoneOS
    _enabled_prev: bool = field(default=False, init=False, repr=False)

    def action(self, act: dict | None) -> dict:
        # Pop the phone keys from the action
        enabled = act.pop("action.phone.enabled", 0)
        pos = act.pop("action.phone.pos", None)
        rot = act.pop("action.phone.rot", None)
        inputs = act.pop("action.phone.raw_inputs", {})

        if pos is None or rot is None:
            return act

        rotvec = rot.as_rotvec()  # Absolute orientation as a rotation vector

        # Map specific inputs to specific actions
        if self.platform == PhoneOS.IOS:
            gripper = float(inputs.get("a3", 0.0))
        else:
            a = float(inputs.get("reservedButtonA", 0.0))
            b = float(inputs.get("reservedButtonB", 0.0))
            # Positive if A is pressed, negative if B is pressed, 0 if both or neither is pressed
            gripper = a - b

        # Some axes need to be inverted
        act.update(
            {
                "action.enabled": enabled,
                "action.target_x": -pos[1] if enabled else 0.0,
                "action.target_y": pos[0] if enabled else 0.0,
                "action.target_z": pos[2] if enabled else 0.0,
                "action.target_wx": rotvec[1] if enabled else 0.0,
                "action.target_wy": rotvec[0] if enabled else 0.0,
                "action.target_wz": -rotvec[2] if enabled else 0.0,
                "action.gripper": gripper,  # Still send the gripper action when disabled
            }
        )
        return act

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
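The axis remapping (phone frame to robot frame, with X and WZ inverted) and the Android two-button gripper can be exercised in isolation. The helper below is a hypothetical standalone re-statement of the mapping, not the library API; it takes the rotation as a rotation vector directly so no SciPy `Rotation` object is needed:

```python
import numpy as np


def map_phone_to_robot(pos, rotvec, inputs, enabled, ios=True):
    """Standalone sketch of MapPhoneActionToRobotAction's mapping.
    pos and rotvec are 3-vectors; inputs is the raw-inputs dict."""
    if ios:
        gripper = float(inputs.get("a3", 0.0))
    else:
        a = float(inputs.get("reservedButtonA", 0.0))
        b = float(inputs.get("reservedButtonB", 0.0))
        gripper = a - b  # +1 if A pressed, -1 if B pressed, 0 if both or neither
    e = bool(enabled)
    return {
        "action.enabled": e,
        # Axis swap/inversion between phone and robot frames
        "action.target_x": -pos[1] if e else 0.0,
        "action.target_y": pos[0] if e else 0.0,
        "action.target_z": pos[2] if e else 0.0,
        "action.target_wx": rotvec[1] if e else 0.0,
        "action.target_wy": rotvec[0] if e else 0.0,
        "action.target_wz": -rotvec[2] if e else 0.0,
        "action.gripper": gripper,  # still sent when disabled
    }


out = map_phone_to_robot(
    pos=np.array([0.1, 0.2, 0.3]),
    rotvec=np.array([0.0, 0.0, 0.5]),
    inputs={"a3": 0.7},
    enabled=True,
)
```

Note that when `enabled` is False all pose targets collapse to 0.0 but the gripper command still passes through, matching the comment in the diff.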
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import os
from typing import Any

import numpy as np
import rerun as rr

from lerobot.processor.pipeline import EnvTransition, TransitionKey


def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
    """Initializes the Rerun SDK for visualizing the control loop."""
@@ -28,19 +31,87 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
    rr.spawn(memory_limit=memory_limit)


def log_rerun_data(observation: dict[str, Any], action: dict[str, Any]):
    for obs, val in observation.items():
        if isinstance(val, float):
            rr.log(f"observation.{obs}", rr.Scalar(val))
        elif isinstance(val, np.ndarray):
            if val.ndim == 1:
                for i, v in enumerate(val):
                    rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v)))
def _is_scalar(x):
    return (
        isinstance(x, numbers.Real)
        or isinstance(x, (np.integer, np.floating))
        or (isinstance(x, np.ndarray) and x.ndim == 0)
    )


def log_rerun_data(
    data: list[dict[str, Any] | EnvTransition] | dict[str, Any] | EnvTransition | None = None,
    *,
    observation: dict[str, Any] | None = None,
    action: dict[str, Any] | None = None,
) -> None:
    items = data if isinstance(data, list) else ([data] if data is not None else [])

    obs = {} if observation is None else dict(observation)
    act = {} if action is None else dict(action)

    for idx, item in enumerate(items):
        if not isinstance(item, dict):
            continue

        if any(isinstance(k, TransitionKey) for k in item.keys()):
            o = item.get(TransitionKey.OBSERVATION) or {}
            a = item.get(TransitionKey.ACTION) or {}
            if isinstance(o, dict):
                obs.update(o)
            if isinstance(a, dict):
                act.update(a)
            continue

        keys = list(item.keys())
        has_obs = any(str(k).startswith("observation.") for k in keys)
        has_act = any(str(k).startswith("action.") for k in keys)

        if has_obs or has_act:
            if has_obs:
                obs.update(item)
            if has_act:
                act.update(item)
        else:
            # No prefixes: assume first is observation, second is action, others are observation
            if idx == 0:
                obs.update(item)
            elif idx == 1:
                act.update(item)
            else:
                rr.log(f"observation.{obs}", rr.Image(val), static=True)
    for act, val in action.items():
        if isinstance(val, float):
            rr.log(f"action.{act}", rr.Scalar(val))
        elif isinstance(val, np.ndarray):
            for i, v in enumerate(val):
                rr.log(f"action.{act}_{i}", rr.Scalar(float(v)))
                obs.update(item)

    for k, v in obs.items():
        if v is None:
            continue
        key = k if str(k).startswith("observation.") else f"observation.{k}"

        if _is_scalar(v):
            rr.log(key, rr.Scalar(float(v)))
        elif isinstance(v, np.ndarray):
            arr = v
            # Convert CHW -> HWC when needed
            if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
                arr = np.transpose(arr, (1, 2, 0))
            if arr.ndim == 1:
                for i, vi in enumerate(arr):
                    rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
            else:
                rr.log(key, rr.Image(arr), static=True)

    for k, v in act.items():
        if v is None:
            continue
        key = k if str(k).startswith("action.") else f"action.{k}"

        if _is_scalar(v):
            rr.log(key, rr.Scalar(float(v)))
        elif isinstance(v, np.ndarray):
            if v.ndim == 1:
                for i, vi in enumerate(v):
                    rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
            else:
                # Fall back to flattening higher-dimensional arrays
                flat = v.flatten()
                for i, vi in enumerate(flat):
                    rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
@@ -0,0 +1,132 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from datasets import Dataset
from huggingface_hub import DatasetCard

from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch, merge_features


def test_default_parameters():
    card = create_lerobot_dataset_card()
    assert isinstance(card, DatasetCard)
    assert card.data.tags == ["LeRobot"]
    assert card.data.task_categories == ["robotics"]
    assert card.data.configs == [
        {
            "config_name": "default",
            "data_files": "data/*/*.parquet",
        }
    ]


def test_with_tags():
    tags = ["tag1", "tag2"]
    card = create_lerobot_dataset_card(tags=tags)
    assert card.data.tags == ["LeRobot", "tag1", "tag2"]


def test_calculate_episode_data_index():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
    assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))


def test_merge_simple_vectors():
    g1 = {
        "action": {
            "dtype": "float32",
            "shape": (2,),
            "names": ["ee.x", "ee.y"],
        }
    }
    g2 = {
        "action": {
            "dtype": "float32",
            "shape": (2,),
            "names": ["ee.y", "ee.z"],
        }
    }

    out = merge_features(g1, g2)

    assert "action" in out
    assert out["action"]["dtype"] == "float32"
    # Names merged with preserved order and de-duplication
    assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"]
    # Shape correctly recomputed from the number of names
    assert out["action"]["shape"] == (3,)


def test_merge_multiple_groups_order_and_dedup():
    g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
    g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
    g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}

    out = merge_features(g1, g2, g3)

    assert out["action"]["names"] == ["a", "b", "c", "d"]
    assert out["action"]["shape"] == (4,)


def test_non_vector_last_wins_for_images():
    # Non-vector entries (images) with the same name are overwritten by the last one specified
    g1 = {
        "observation.images.front": {
            "dtype": "image",
            "shape": (3, 480, 640),
            "names": ["channels", "height", "width"],
        }
    }
    g2 = {
        "observation.images.front": {
            "dtype": "image",
            "shape": (3, 720, 1280),
            "names": ["channels", "height", "width"],
        }
    }

    out = merge_features(g1, g2)
    assert out["observation.images.front"]["shape"] == (3, 720, 1280)
    assert out["observation.images.front"]["dtype"] == "image"


def test_dtype_mismatch_raises():
    g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}}
    g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}

    with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
        _ = merge_features(g1, g2)


def test_non_dict_passthrough_last_wins():
    g1 = {"misc": 123}
    g2 = {"misc": 456}

    out = merge_features(g1, g2)
    # For non-dict entries the last one wins
    assert out["misc"] == 456
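The name-merging rule these tests pin down (preserve first-seen order, drop duplicates, recompute shape from the merged name count) can be restated in a few lines. `merge_names` below is a hypothetical helper illustrating the rule, not the library's `merge_features` implementation:

```python
def merge_names(*groups: list[str]) -> list[str]:
    """Order-preserving, de-duplicating merge of feature-name lists."""
    out: list[str] = []
    for names in groups:
        for n in names:
            if n not in out:
                out.append(n)
    return out


merged = merge_names(["ee.x", "ee.y"], ["ee.y", "ee.z"])
print(merged)  # -> ['ee.x', 'ee.y', 'ee.z']
print((len(merged),))  # the recomputed vector shape -> (3,)
```

This is the same semantics as a `dict.fromkeys` union, but the explicit loop makes the first-seen ordering guarantee obvious.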
@@ -1,55 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from datasets import Dataset
from huggingface_hub import DatasetCard

from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch


def test_default_parameters():
    card = create_lerobot_dataset_card()
    assert isinstance(card, DatasetCard)
    assert card.data.tags == ["LeRobot"]
    assert card.data.task_categories == ["robotics"]
    assert card.data.configs == [
        {
            "config_name": "default",
            "data_files": "data/*/*.parquet",
        }
    ]


def test_with_tags():
    tags = ["tag1", "tag2"]
    card = create_lerobot_dataset_card(tags=tags)
    assert card.data.tags == ["LeRobot", "tag1", "tag2"]


def test_calculate_episode_data_index():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    episode_data_index = calculate_episode_data_index(dataset)
    assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
    assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
@@ -0,0 +1,196 @@
import numpy as np
import pytest
import torch

from lerobot.processor.converters import (
    to_dataset_frame,
    to_output_robot_action,
    to_transition_robot_observation,
    to_transition_teleop_action,
)
from lerobot.processor.pipeline import TransitionKey


def test_to_transition_teleop_action_prefix_and_tensor_conversion():
    # Scalars, arrays, and "image-like" uint8 arrays are supported
    img = np.zeros((8, 12, 3), dtype=np.uint8)
    act = {
        "ee.x": 0.5,  # scalar -> torch tensor
        "delta": np.array([1.0, 2.0]),  # ndarray -> torch tensor
        "raw_img": img,  # uint8 HWC -> passthrough ndarray
    }

    tr = to_transition_teleop_action(act)

    # Should be an EnvTransition-like dict with ACTION populated
    assert isinstance(tr, dict)
    assert TransitionKey.ACTION in tr
    assert "action.ee.x" in tr[TransitionKey.ACTION]
    assert "action.delta" in tr[TransitionKey.ACTION]
    assert "action.raw_img" in tr[TransitionKey.ACTION]

    # Types: scalars/arrays -> torch tensor; images -> np.ndarray
    assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor)
    assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5)

    assert isinstance(tr[TransitionKey.ACTION]["action.delta"], torch.Tensor)
    assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,)
    assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0]))

    assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], np.ndarray)
    assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == np.uint8
    assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3)

    # Observation is created as an empty dict by make_transition
    assert TransitionKey.OBSERVATION in tr
    assert isinstance(tr[TransitionKey.OBSERVATION], dict)
    assert tr[TransitionKey.OBSERVATION] == {}


def test_to_transition_robot_observation_state_vs_images_split():
    # Create an observation with mixed content
    img = np.full((10, 20, 3), 255, dtype=np.uint8)  # image (uint8 HWC)
    obs = {
        "j1.pos": 10.0,  # scalar -> state -> torch tensor
        "j2.pos": np.float32(20.0),  # numpy scalar -> state -> torch tensor
        "image_front": img,  # -> images passthrough
        "flag": np.int32(7),  # scalar -> state -> torch tensor
        "arr": np.array([1.5, 2.5]),  # vector -> state -> torch tensor
    }

    tr = to_transition_robot_observation(obs)
    assert isinstance(tr, dict)
    assert TransitionKey.OBSERVATION in tr

    out = tr[TransitionKey.OBSERVATION]
    # Check that state keys are present and converted to tensors
    for k in ("j1.pos", "j2.pos", "flag", "arr"):
        key = f"observation.state.{k}"
        assert key in out
        v = out[key]
        if k != "arr":
            assert isinstance(v, torch.Tensor) and v.ndim == 0
        else:
            assert isinstance(v, torch.Tensor) and v.ndim == 1 and v.shape == (2,)

    # Check that the image is present as-is
    assert "observation.images.image_front" in out
    assert isinstance(out["observation.images.image_front"], np.ndarray)
    assert out["observation.images.image_front"].dtype == np.uint8
    assert out["observation.images.image_front"].shape == (10, 20, 3)

    # ACTION should be an empty dict by make_transition
    assert TransitionKey.ACTION in tr
    assert isinstance(tr[TransitionKey.ACTION], dict)
    assert tr[TransitionKey.ACTION] == {}


def test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only():
    # Build a transition with mixed action keys
    tr = {
        TransitionKey.ACTION: {
            "action.j1.pos": 11.0,  # keep: "j1.pos"
            "action.gripper.pos": torch.tensor(33.0),  # keep: tensor accepted
            "action.ee.x": 0.5,  # ignore (doesn't end with .pos)
            "misc": "ignore_me",  # ignore (no "action." prefix)
        }
    }

    out = to_output_robot_action(tr)
    # Only ".pos" keys with the "action." prefix are retained, stripped to base names
    assert set(out.keys()) == {"j1.pos", "gripper.pos"}
    # Values converted to float
    assert isinstance(out["j1.pos"], float)
    assert isinstance(out["gripper.pos"], float)
    assert out["j1.pos"] == pytest.approx(11.0)
    assert out["gripper.pos"] == pytest.approx(33.0)


def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
    # Fabricate dataset features (as stored in dataset.meta["features"])
    features = {
        # Action vector: 3 elements in a specific order
        "action": {
            "dtype": "float32",
            "shape": (3,),
            "names": ["j1.pos", "j2.pos", "gripper.pos"],
        },
        # Observation state vector: 2 elements
        "observation.state": {
            "dtype": "float32",
            "shape": (2,),
            "names": ["j1.pos", "j2.pos"],
        },
        # Image spec (video/image dtype acceptable)
        "observation.images.front": {
            "dtype": "image",
            "shape": (480, 640, 3),
            "names": ["h", "w", "c"],
        },
    }

    # Build two transitions to be merged: teleop (action) and robot obs (state/images)
    img = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)

    teleop_transition = {
        TransitionKey.OBSERVATION: {},
        TransitionKey.ACTION: {
            "action.j1.pos": torch.tensor(1.1),
            "action.j2.pos": torch.tensor(2.2),
            # gripper.pos missing → defaults to 0.0
            "action.ee.x": 0.5,  # ignored, not in features["action"]["names"]
        },
        TransitionKey.COMPLEMENTARY_DATA: {
            "frame_is_pad": True,
            "task": "Pick cube",
        },
    }

    robot_transition = {
        TransitionKey.OBSERVATION: {
            "observation.state.j1.pos": torch.tensor(10.0),
            "observation.state.j2.pos": torch.tensor(20.0),
            "observation.images.front": img,
        },
        TransitionKey.REWARD: torch.tensor(5.0),
        TransitionKey.DONE: True,
        TransitionKey.TRUNCATED: False,
        TransitionKey.INFO: {"note": "ok"},
    }

    # Directly call the refactored function
    batch = to_dataset_frame([teleop_transition, robot_transition], features)

    # Images passthrough
    assert "observation.images.front" in batch
    assert batch["observation.images.front"].shape == img.shape
    assert batch["observation.images.front"].dtype == np.uint8
    assert np.shares_memory(batch["observation.images.front"], img) or np.array_equal(
        batch["observation.images.front"], img
    )

    # observation.state vector
    assert "observation.state" in batch
    obs_vec = batch["observation.state"]
    assert isinstance(obs_vec, np.ndarray) and obs_vec.dtype == np.float32
    assert obs_vec.shape == (2,)
    assert obs_vec[0] == pytest.approx(10.0)
    assert obs_vec[1] == pytest.approx(20.0)

    # Action vector
    assert "action" in batch
    act_vec = batch["action"]
    assert isinstance(act_vec, np.ndarray) and act_vec.dtype == np.float32
    assert act_vec.shape == (3,)
    assert act_vec[0] == pytest.approx(1.1)
    assert act_vec[1] == pytest.approx(2.2)
    assert act_vec[2] == pytest.approx(0.0)  # default for missing gripper.pos

    # next.* metadata
    assert batch["next.reward"] == pytest.approx(5.0)
    assert batch["next.done"] is True
    assert batch["next.truncated"] is False

    # Complementary data
    assert batch["frame_is_pad"] is True
    assert batch["task"] == "Pick cube"
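The filtering contract tested by `test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only` is simple enough to restate standalone: keep only keys that both start with `action.` and end with `.pos`, strip the prefix, and coerce values to `float`. The helper below is a hypothetical re-statement of that rule, not the library's `to_output_robot_action` (which also handles tensors):

```python
def strip_action_pos(action: dict) -> dict:
    """Keep only 'action.*.pos' entries, strip the 'action.' prefix,
    and convert values to plain floats."""
    out = {}
    for k, v in action.items():
        if k.startswith("action.") and k.endswith(".pos"):
            out[k.removeprefix("action.")] = float(v)
    return out


res = strip_action_pos(
    {
        "action.j1.pos": 11.0,  # kept as "j1.pos"
        "action.gripper.pos": 33.0,  # kept as "gripper.pos"
        "action.ee.x": 0.5,  # dropped: doesn't end with ".pos"
        "misc": 99,  # dropped: no "action." prefix
    }
)
print(res)  # -> {'j1.pos': 11.0, 'gripper.pos': 33.0}
```

Coercing to `float` at this boundary guarantees the robot driver receives plain Python numbers regardless of whether the pipeline produced tensors or scalars upstream.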
@@ -288,8 +288,8 @@ def test_serialization_methods():
|
||||
assert processor.device == device
|
||||
|
||||
|
||||
def test_feature_contract():
|
||||
"""Test that feature_contract returns features unchanged."""
|
||||
def test_features():
|
||||
"""Test that features returns features unchanged."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
|
||||
features = {
|
||||
@@ -297,7 +297,7 @@ def test_feature_contract():
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
|
||||
}
|
||||
|
||||
result = processor.feature_contract(features)
|
||||
result = processor.transform_features(features)
|
||||
assert result == features
|
||||
assert result is features # Should return the same object
|
||||
|
||||
|
||||
@@ -621,10 +621,19 @@ def test_serialization_roundtrip(full_stats):
|
||||
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
|
||||
|
||||
# Verify features and norm_map are correctly reconstructed
|
||||
assert new_processor.features.keys() == original_processor.features.keys()
|
||||
for key in new_processor.features:
|
||||
assert new_processor.features[key].type == original_processor.features[key].type
|
||||
assert new_processor.features[key].shape == original_processor.features[key].shape
|
||||
assert (
|
||||
new_processor.transform_features(features).keys()
|
||||
== original_processor.transform_features(features).keys()
|
||||
)
|
||||
for key in new_processor.transform_features(features):
|
||||
assert (
|
||||
new_processor.transform_features(features)[key].type
|
||||
== original_processor.transform_features(features)[key].type
|
||||
)
|
||||
assert (
|
||||
new_processor.transform_features(features)[key].shape
|
||||
== original_processor.transform_features(features)[key].shape
|
||||
)
|
||||
|
||||
assert new_processor.norm_map == original_processor.norm_map
|
||||
|
||||
|
||||
@@ -410,13 +410,13 @@ def test_equivalent_with_image_dict():
|
||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
|
||||
def test_image_processor_features_pixels_to_image(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
out = processor.feature_contract(features.copy())
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
|
||||
assert "pixels" not in out
|
||||
@@ -424,13 +424,13 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
|
||||
def test_image_processor_features_observation_pixels_to_image(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
out = processor.feature_contract(features.copy())
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
|
||||
assert "observation.pixels" not in out
|
||||
@@ -438,7 +438,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
|
||||
def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
@@ -446,7 +446,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu
|
||||
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
|
||||
}
|
||||
out = processor.feature_contract(features.copy())
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
|
||||
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
|
||||
@@ -456,14 +456,14 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
|
||||
def test_state_processor_features_environment_and_agent_pos(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
out = processor.feature_contract(features.copy())
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
|
||||
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
|
||||
@@ -472,13 +472,13 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
|
||||
def test_state_processor_features_prefixed_inputs(policy_feature_factory):
|
||||
proc = VanillaObservationProcessor()
|
||||
features = {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
}
|
||||
out = proc.feature_contract(features.copy())
|
||||
out = proc.transform_features(features.copy())
|
||||
|
||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
|
||||
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
|
||||
|
||||
@@ -26,6 +26,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
@@ -90,8 +91,8 @@ class MockStep:
|
||||
def reset(self) -> None:
|
||||
self.counter = 0
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
|
||||
@@ -112,8 +113,8 @@ class MockStepWithoutOptionalMethods:
|
||||
|
||||
return transition
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
|
||||
@@ -168,8 +169,8 @@ class MockStepWithTensorState:
|
||||
self.running_mean.zero_()
|
||||
self.running_count.zero_()
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
|
||||
@@ -662,8 +663,8 @@ class MockModuleStep(nn.Module):
|
||||
self.running_mean.zero_()
|
||||
self.counter = 0
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
|
||||
@@ -744,8 +745,8 @@ class MockNonModuleStepWithState:
|
||||
self.step_count.zero_()
|
||||
self.history.clear()
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
|
||||
@@ -799,8 +800,8 @@ class MockStepWithNonSerializableParam:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
|
||||
@@ -838,8 +839,8 @@ class RegisteredMockStep:
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test feature_contract here
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
|
||||
@@ -1382,8 +1383,8 @@ def test_state_file_naming_with_registry():
         def load_state_dict(self, state):
             self.state_tensor = state["state_tensor"]

-        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
-            # We do not test feature_contract here
+        def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+            # We do not test features here
             return features

     try:
@@ -1439,8 +1440,8 @@ def test_override_with_nested_config():
         def get_config(self):
            return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}

-        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
-            # We do not test feature_contract here
+        def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+            # We do not test features here
             return features

     try:
@@ -1531,8 +1532,8 @@ def test_override_with_callables():
         def get_config(self):
             return {"name": self.name}

-        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
-            # We do not test feature_contract here
+        def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+            # We do not test features here
             return features

     try:
@@ -1766,8 +1767,8 @@ def test_override_with_device_strings():
         def load_state_dict(self, state):
             self.buffer = state["buffer"]

-        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
-            # We do not test feature_contract here
+        def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+            # We do not test features here
             return features

     try:
@@ -1860,21 +1861,16 @@ def test_save_load_with_custom_converter_functions():


 class NonCompliantStep:
-    """Intentionally non-compliant: missing feature_contract."""
+    """Intentionally non-compliant: missing features."""

     def __call__(self, transition: EnvTransition) -> EnvTransition:
         return transition


-def test_construction_rejects_step_without_feature_contract():
-    with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"):
-        RobotProcessor([NonCompliantStep()])
-
-
 class NonCallableStep:
     """Intentionally non-compliant: missing __call__."""

-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return features
@@ -1893,7 +1889,7 @@ class FeatureContractAddStep:
     def __call__(self, transition: EnvTransition) -> EnvTransition:
         return transition

-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         features[self.key] = self.value
         return features
@@ -1908,7 +1904,7 @@ class FeatureContractMutateStep:
    def __call__(self, transition: EnvTransition) -> EnvTransition:
         return transition

-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         features[self.key] = self.fn(features.get(self.key))
         return features
@@ -1920,7 +1916,7 @@ class FeatureContractBadReturnStep:
     def __call__(self, transition: EnvTransition) -> EnvTransition:
         return transition

-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         return ["not-a-dict"]
@@ -1933,12 +1929,12 @@ class FeatureContractRemoveStep:
     def __call__(self, transition: EnvTransition) -> EnvTransition:
         return transition

-    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
         features.pop(self.key, None)
         return features


-def test_feature_contract_orders_and_merges(policy_feature_factory):
+def test_features_orders_and_merges(policy_feature_factory):
     p = RobotProcessor(
         [
             FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
@@ -1946,14 +1942,14 @@ def test_feature_contract_orders_and_merges(policy_feature_factory):
             FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))),
         ]
     )
-    out = p.feature_contract({})
+    out = p.transform_features({})

     assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,)
     assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,)
     assert_contract_is_typed(out)


-def test_feature_contract_respects_initial_without_mutation(policy_feature_factory):
+def test_features_respects_initial_without_mutation(policy_feature_factory):
     initial = {
         "seed": policy_feature_factory(FeatureType.STATE, (7,)),
         "nested": policy_feature_factory(FeatureType.ENV, (0,)),
@@ -1966,7 +1962,7 @@ def test_feature_contract_respects_initial_without_mutation(policy_feature_facto
             ),
         ]
     )
-    out = p.feature_contract(initial_features=initial)
+    out = p.transform_features(initial_features=initial)

     assert out["seed"].shape == (8,)
     assert out["nested"].shape == (5,)
@@ -1977,13 +1973,7 @@ def test_feature_contract_respects_initial_without_mutation(policy_feature_facto
     assert_contract_is_typed(out)


-def test_feature_contract_type_error_on_bad_step():
-    p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()])
-    with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"):
-        _ = p.feature_contract({})
-
-
-def test_feature_contract_execution_order_tracking():
+def test_features_execution_order_tracking():
     class Track:
         def __init__(self, label):
             self.label = label
@@ -1991,32 +1981,186 @@ def test_feature_contract_execution_order_tracking():
         def __call__(self, transition: EnvTransition) -> EnvTransition:
             return transition

-        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+        def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
             code = {"A": 1, "B": 2, "C": 3}[self.label]
             pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=()))
             features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,))
             return features

-    out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({})
+    out = RobotProcessor([Track("A"), Track("B"), Track("C")]).transform_features({})
     assert out["order"].shape == (1, 2, 3)


-def test_feature_contract_remove_key(policy_feature_factory):
+def test_features_remove_key(policy_feature_factory):
     p = RobotProcessor(
         [
             FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
             FeatureContractRemoveStep("a"),
         ]
     )
-    out = p.feature_contract({})
+    out = p.transform_features({})
     assert "a" not in out


-def test_feature_contract_remove_from_initial(policy_feature_factory):
+def test_features_remove_from_initial(policy_feature_factory):
     initial = {
         "keep": policy_feature_factory(FeatureType.STATE, (1,)),
         "drop": policy_feature_factory(FeatureType.STATE, (1,)),
     }
     p = RobotProcessor([FeatureContractRemoveStep("drop")])
-    out = p.feature_contract(initial_features=initial)
+    out = p.transform_features(initial_features=initial)
     assert "drop" not in out and out["keep"] == initial["keep"]


+@dataclass
+class AddActionEEAndJointFeatures:
+    """Adds both EE and JOINT action features."""
+
+    def __call__(self, tr):
+        return tr
+
+    def transform_features(self, features: dict) -> dict:
+        # EE features
+        features["action.ee.x"] = float
+        features["action.ee.y"] = float
+        # JOINT features
+        features["action.j1.pos"] = float
+        features["action.j2.pos"] = float
+        return features
+
+
+@dataclass
+class AddObservationStateFeatures:
+    """Adds state features (and optionally an image spec to test precedence)."""
+
+    add_front_image: bool = False
+    front_image_shape: tuple = (240, 320, 3)
+
+    def __call__(self, tr):
+        return tr
+
+    def transform_features(self, features: dict) -> dict:
+        # State features (mix EE and a joint state)
+        features["observation.state.ee.x"] = float
+        features["observation.state.j1.pos"] = float
+        if self.add_front_image:
+            features["observation.images.front"] = self.front_image_shape
+        return features
+
+
+def test_aggregate_joint_action_only():
+    rp = RobotProcessor([AddActionEEAndJointFeatures()])
+    initial = {"front": (480, 640, 3)}
+
+    out = aggregate_pipeline_dataset_features(
+        pipeline=rp,
+        initial_features=initial,
+        use_videos=True,
+        patterns=["action.j1.pos", "action.j2.pos"],
+    )
+
+    # Expect only "action" with joint names
+    assert "action" in out and "observation.state" not in out
+    assert out["action"]["dtype"] == "float32"
+    assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"}
+    assert out["action"]["shape"] == (len(out["action"]["names"]),)
+
+
+def test_aggregate_ee_action_and_observation_with_videos():
+    rp = RobotProcessor([AddActionEEAndJointFeatures(), AddObservationStateFeatures()])
+    initial = {"front": (480, 640, 3), "side": (720, 1280, 3)}
+
+    out = aggregate_pipeline_dataset_features(
+        pipeline=rp,
+        initial_features=initial,
+        use_videos=True,
+        patterns=["action.ee", "observation.state"],
+    )
+
+    # Action should pack only EE names
+    assert "action" in out
+    assert set(out["action"]["names"]) == {"ee.x", "ee.y"}
+    assert out["action"]["dtype"] == "float32"
+
+    # Observation state should pack both ee.x and j1.pos as a vector
+    assert "observation.state" in out
+    assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"}
+    assert out["observation.state"]["dtype"] == "float32"
+
+    # Cameras from initial_features appear as videos
+    for cam in ("front", "side"):
+        key = f"observation.images.{cam}"
+        assert key in out
+        assert out[key]["dtype"] == "video"
+        assert out[key]["shape"] == initial[cam]
+        assert out[key]["names"] == ["height", "width", "channels"]
+
+
+def test_aggregate_both_action_types():
+    rp = RobotProcessor([AddActionEEAndJointFeatures()])
+    out = aggregate_pipeline_dataset_features(
+        pipeline=rp,
+        initial_features={},
+        use_videos=True,
+        patterns=["action.ee", "action.j1", "action.j2.pos"],
+    )
+
+    assert "action" in out
+    expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"}
+    assert set(out["action"]["names"]) == expected
+    assert out["action"]["shape"] == (len(expected),)
+
+
+def test_aggregate_images_when_use_videos_false():
+    rp = RobotProcessor([AddObservationStateFeatures(add_front_image=True)])
+    initial = {"back": (480, 640, 3)}
+
+    out = aggregate_pipeline_dataset_features(
+        pipeline=rp,
+        initial_features=initial,
+        use_videos=False,  # expect "image" dtype
+        patterns=None,
+    )
+
+    key = "observation.images.back"
+    key_front = "observation.images.front"
+    assert key not in out
+    assert key_front not in out
+
+
+def test_aggregate_images_when_use_videos_true():
+    rp = RobotProcessor([AddObservationStateFeatures(add_front_image=True)])
+    initial = {"back": (480, 640, 3)}
+
+    out = aggregate_pipeline_dataset_features(
+        pipeline=rp,
+        initial_features=initial,
+        use_videos=True,
+        patterns=None,
+    )
+
+    key = "observation.images.front"
+    key_back = "observation.images.back"
+    assert key in out
+    assert key_back in out
+    assert out[key]["dtype"] == "video"
+    assert out[key_back]["dtype"] == "video"
+    assert out[key_back]["shape"] == initial["back"]
+
+
+def test_initial_camera_not_overridden_by_step_image():
+    # Step explicitly sets a different front image shape; initial has another shape.
+    # aggregate_pipeline_dataset_features should keep the step's value (setdefault behavior on initial cams).
+    rp = RobotProcessor([AddObservationStateFeatures(add_front_image=True, front_image_shape=(240, 320, 3))])
+    initial = {"front": (480, 640, 3)}  # should NOT override the step-provided (240, 320, 3)
+
+    out = aggregate_pipeline_dataset_features(
+        pipeline=rp,
+        initial_features=initial,
+        use_videos=True,
+        patterns=["observation.images.front"],
+    )
+
+    key = "observation.images.front"
+    assert key in out
+    assert out[key]["shape"] == (240, 320, 3)  # from the step, not from initial
@@ -410,7 +410,7 @@ def test_value_types_preserved():
     assert processed_obs["old_list"] == [1, 2, 3]


-def test_feature_contract_basic_renaming(policy_feature_factory):
+def test_features_basic_renaming(policy_feature_factory):
     processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
     features = {
         "a": policy_feature_factory(FeatureType.STATE, (2,)),
@@ -418,7 +418,7 @@ def test_feature_contract_basic_renaming(policy_feature_factory):
         "c": policy_feature_factory(FeatureType.ENV, (1,)),
     }

-    out = processor.feature_contract(features.copy())
+    out = processor.transform_features(features.copy())

     # Values preserved and typed
     assert out["x"] == features["a"]
@@ -430,14 +430,14 @@ def test_feature_contract_basic_renaming(policy_feature_factory):
     assert set(features) == {"a", "b", "c"}


-def test_feature_contract_overlapping_keys(policy_feature_factory):
+def test_features_overlapping_keys(policy_feature_factory):
     # Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
     processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
     features = {
         "a": policy_feature_factory(FeatureType.STATE, (1,)),
         "b": policy_feature_factory(FeatureType.STATE, (2,)),
     }
-    out = processor.feature_contract(features)
+    out = processor.transform_features(features)

     assert set(out) == {"b", "c"}
     assert out["b"] == features["a"]  # 'a' renamed to 'b'
@@ -445,7 +445,7 @@ def test_feature_contract_overlapping_keys(policy_feature_factory):
     assert_contract_is_typed(out)


-def test_feature_contract_chained_processors(policy_feature_factory):
+def test_features_chained_processors(policy_feature_factory):
     # Chain two rename processors at the contract level
     processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
     processor2 = RenameProcessor(
@@ -458,7 +458,7 @@ def test_feature_contract_chained_processors(policy_feature_factory):
         "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
         "extra": policy_feature_factory(FeatureType.ENV, (1,)),
     }
-    out = pipeline.feature_contract(initial_features=spec)
+    out = pipeline.transform_features(initial_features=spec)

     assert set(out) == {"observation.state", "observation.image", "extra"}
     assert out["observation.state"] == spec["pos"]
@@ -470,7 +470,7 @@ def test_registry_functionality():


 @require_package("transformers")
-def test_feature_contract_basic():
+def test_features_basic():
     """Test basic feature contract functionality."""
     mock_tokenizer = MockTokenizer(vocab_size=100)
     processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=128)
@@ -480,7 +480,7 @@ def test_feature_contract_basic():
         "action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
     }

-    output_features = processor.feature_contract(input_features)
+    output_features = processor.transform_features(input_features)

     # Check that original features are preserved
     assert "observation.state" in output_features
@@ -501,13 +501,13 @@ def test_feature_contract_basic():


 @require_package("transformers")
-def test_feature_contract_with_custom_max_length():
+def test_features_with_custom_max_length():
     """Test feature contract with custom max_length."""
     mock_tokenizer = MockTokenizer(vocab_size=100)
     processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=64)

     input_features = {}
-    output_features = processor.feature_contract(input_features)
+    output_features = processor.transform_features(input_features)

     # Check that features use correct max_length
     assert f"{OBS_LANGUAGE}.tokens" in output_features
@@ -521,7 +521,7 @@ def test_feature_contract_with_custom_max_length():


 @require_package("transformers")
-def test_feature_contract_existing_features():
+def test_features_existing_features():
     """Test feature contract when tokenized features already exist."""
     mock_tokenizer = MockTokenizer(vocab_size=100)
     processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=256)
@@ -531,7 +531,7 @@ def test_feature_contract_existing_features():
         f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
     }

-    output_features = processor.feature_contract(input_features)
+    output_features = processor.transform_features(input_features)

     # Should not overwrite existing features
     assert output_features[f"{OBS_LANGUAGE}.tokens"].shape == (100,)  # Original shape preserved
@@ -0,0 +1,205 @@
+import importlib
+import sys
+from types import SimpleNamespace
+
+import numpy as np
+import pytest
+
+from lerobot.processor.pipeline import TransitionKey
+
+
+@pytest.fixture
+def mock_rerun(monkeypatch):
+    """
+    Provide a mock `rerun` module so tests don't depend on the real library.
+    Also reload the module-under-test so it binds to this mock `rr`.
+    """
+    calls = []
+
+    class DummyScalar:
+        def __init__(self, value):
+            self.value = float(value)
+
+    class DummyImage:
+        def __init__(self, arr):
+            self.arr = arr
+
+    def dummy_log(key, obj, **kwargs):
+        calls.append((key, obj, kwargs))
+
+    dummy_rr = SimpleNamespace(
+        Scalar=DummyScalar,
+        Image=DummyImage,
+        log=dummy_log,
+        init=lambda *a, **k: None,
+        spawn=lambda *a, **k: None,
+    )
+
+    # Inject fake module into sys.modules
+    monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
+
+    # Now import and reload the module under test, to bind to our rerun mock
+    import lerobot.utils.visualization_utils as vu
+
+    importlib.reload(vu)
+
+    # Expose both the reloaded module and the call recorder
+    yield vu, calls
+
+
+def _keys(calls):
+    """Helper to extract just the keys logged to rr.log"""
+    return [k for (k, _obj, _kw) in calls]
+
+
+def _obj_for(calls, key):
+    """Find the first object logged under a given key."""
+    for k, obj, _kw in calls:
+        if k == key:
+            return obj
+    raise KeyError(f"Key {key} not found in calls: {calls}")
+
+
+def _kwargs_for(calls, key):
+    for k, _obj, kw in calls:
+        if k == key:
+            return kw
+    raise KeyError(f"Key {key} not found in calls: {calls}")
+
+
+def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
+    vu, calls = mock_rerun
+
+    # Build EnvTransition dict
+    obs = {
+        "observation.state.temperature": np.float32(25.0),
+        # CHW image should be converted to HWC for rr.Image
+        "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
+    }
+    act = {
+        "action.throttle": 0.7,
+        # 1D array should log individual Scalars with suffix _i
+        "action.vector": np.array([1.0, 2.0], dtype=np.float32),
+    }
+    transition = {
+        TransitionKey.OBSERVATION: obs,
+        TransitionKey.ACTION: act,
+    }
+
+    vu.log_rerun_data(transition)
+
+    # We expect:
+    # - observation.state.temperature -> Scalar
+    # - observation.camera -> Image (HWC) with static=True
+    # - action.throttle -> Scalar
+    # - action.vector_0, action.vector_1 -> Scalars
+    expected_keys = {
+        "observation.state.temperature",
+        "observation.camera",
+        "action.throttle",
+        "action.vector_0",
+        "action.vector_1",
+    }
+    assert set(_keys(calls)) == expected_keys
+
+    # Check scalar types and values
+    temp_obj = _obj_for(calls, "observation.state.temperature")
+    assert type(temp_obj).__name__ == "DummyScalar"
+    assert temp_obj.value == pytest.approx(25.0)
+
+    throttle_obj = _obj_for(calls, "action.throttle")
+    assert type(throttle_obj).__name__ == "DummyScalar"
+    assert throttle_obj.value == pytest.approx(0.7)
+
+    v0 = _obj_for(calls, "action.vector_0")
+    v1 = _obj_for(calls, "action.vector_1")
+    assert type(v0).__name__ == "DummyScalar"
+    assert type(v1).__name__ == "DummyScalar"
+    assert v0.value == pytest.approx(1.0)
+    assert v1.value == pytest.approx(2.0)
+
+    # Check image handling: CHW -> HWC
+    img_obj = _obj_for(calls, "observation.camera")
+    assert type(img_obj).__name__ == "DummyImage"
+    assert img_obj.arr.shape == (10, 20, 3)  # transposed
+    assert _kwargs_for(calls, "observation.camera").get("static", False) is True  # static=True for images
+
+
+def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
+    vu, calls = mock_rerun
+
+    # First dict without prefixes treated as observation
+    # Second dict without prefixes treated as action
+    obs_plain = {
+        "temp": 1.5,
+        # Already HWC image => should stay as-is
+        "img": np.zeros((5, 6, 3), dtype=np.uint8),
+        "none": None,  # should be skipped
+    }
+    act_plain = {
+        "throttle": 0.3,
+        "vec": np.array([9, 8, 7], dtype=np.float32),
+    }
+
+    vu.log_rerun_data([obs_plain, act_plain])
+
+    # Expected keys with auto-prefixes
+    expected = {
+        "observation.temp",
+        "observation.img",
+        "action.throttle",
+        "action.vec_0",
+        "action.vec_1",
+        "action.vec_2",
+    }
+    logged = set(_keys(calls))
+    assert logged == expected
+
+    # Scalars
+    t = _obj_for(calls, "observation.temp")
+    assert type(t).__name__ == "DummyScalar"
+    assert t.value == pytest.approx(1.5)
+
+    throttle = _obj_for(calls, "action.throttle")
+    assert type(throttle).__name__ == "DummyScalar"
+    assert throttle.value == pytest.approx(0.3)
+
+    # Image stays HWC
+    img = _obj_for(calls, "observation.img")
+    assert type(img).__name__ == "DummyImage"
+    assert img.arr.shape == (5, 6, 3)
+    assert _kwargs_for(calls, "observation.img").get("static", False) is True
+
+    # Vectors
+    for i, val in enumerate([9, 8, 7]):
+        o = _obj_for(calls, f"action.vec_{i}")
+        assert type(o).__name__ == "DummyScalar"
+        assert o.value == pytest.approx(val)
+
+
+def test_log_rerun_data_kwargs_only(mock_rerun):
+    vu, calls = mock_rerun
+
+    vu.log_rerun_data(
+        None,
+        observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
+        action={"action.a": 1.0},
+    )
+
+    keys = set(_keys(calls))
+    assert "observation.temp" in keys
+    assert "observation.gray" in keys
+    assert "action.a" in keys
+
+    temp = _obj_for(calls, "observation.temp")
+    assert type(temp).__name__ == "DummyScalar"
+    assert temp.value == pytest.approx(10.0)
+
+    img = _obj_for(calls, "observation.gray")
+    assert type(img).__name__ == "DummyImage"
+    assert img.arr.shape == (8, 8, 1)  # remains HWC
+    assert _kwargs_for(calls, "observation.gray").get("static", False) is True
+
+    a = _obj_for(calls, "action.a")
+    assert type(a).__name__ == "DummyScalar"
+    assert a.value == pytest.approx(1.0)