Compare commits

..

55 Commits

Author SHA1 Message Date
Pepijn 112eb70a65 Add uniform sampling and transition smoothing 2025-11-28 17:15:57 +01:00
Pepijn 6e3b972534 add visualize subtask annotations 2025-11-28 16:59:29 +01:00
Pepijn fa5004bd8c fix formatting 2025-11-28 13:27:20 +01:00
Pepijn b98c70376b Fix visualization and change prompt 2025-11-28 12:16:16 +01:00
Pepijn 2fa045eedc fix normalization in visualization 2025-11-28 10:52:24 +01:00
Pepijn adc476d8af simplify and cleanup code and move compute_temporal_proportions to utils 2025-11-27 19:38:32 +01:00
Pepijn 73dd4f10f7 simplify 2025-11-27 17:36:00 +01:00
Pepijn 2889c0650a use task from dataset, cleanup visualizer 2025-11-27 14:14:52 +01:00
Pepijn f2ad86831d add tests, implement formula 1,2 correctly and cleanup 2025-11-27 14:04:01 +01:00
Pepijn 3ed0425d2c Remove rewind, use clip tokenizer 2025-11-26 21:06:20 +01:00
Pepijn 425eced2de use large offset for initial frame (ugly) 2025-11-26 11:53:12 +01:00
Pepijn cc2e91febe fix progress conversion and adding initial frame 2025-11-26 11:02:42 +01:00
Pepijn c66aef878c add small logging 2025-11-25 22:54:35 +01:00
Pepijn 599c2477c5 change loadig subtasks 2025-11-25 22:48:46 +01:00
Pepijn 456d9fe3ff pass dataset metadata to policy 2025-11-25 22:13:23 +01:00
Pepijn 006185ff4a revert lerobot_train changes 2025-11-25 22:09:27 +01:00
Pepijn 2dc2a3ae55 add subtask init and detection 2025-11-25 22:06:20 +01:00
Pepijn 0c99b768f4 add episode inddex to complementary data 2025-11-25 18:34:56 +01:00
Pepijn c774818eda cleanup and refactor 2025-11-25 17:47:36 +01:00
Pepijn 3b31c2d9d3 pass stats 2025-11-25 16:25:58 +01:00
Pepijn 6b6a82bbdf raise if no state key is found 2025-11-25 16:21:29 +01:00
Pepijn 7beb20819e get state input from dataset stats 2025-11-25 16:17:28 +01:00
Pepijn 9a5a0ad575 change validation 2025-11-25 16:03:13 +01:00
Pepijn 2af40615b8 add image validation 2025-11-25 14:48:52 +01:00
Pepijn 8d2fb5d298 change expected features 2025-11-25 13:51:01 +01:00
Pepijn d286ea30d4 add reward output 2025-11-25 13:44:04 +01:00
Pepijn ca67231892 update sarm processor 2025-11-25 13:40:04 +01:00
Pepijn 5245332e36 print batch size 2025-11-25 13:26:30 +01:00
Pepijn 4367348327 change order train log 2025-11-25 13:05:56 +01:00
Pepijn c2c0dbf52e cleanup 2025-11-25 11:49:27 +01:00
Pepijn 3d28dc3681 Merge branch 'main' into feat/add_rewind 2025-11-24 19:23:05 +01:00
Michel Aractingi 0f551df8f4 add absolute_to_reative_idx for remapping indicies when a subset of data is loaded (#2490) 2025-11-20 14:05:31 +01:00
Jade Choghari 6e86a69dcd feat(envs): add envs pre-post processor (#2474)
* more changes

* working changes

* more changes

* more fixes

* fix style

* more

* clean

* put axis-1

* more fixes

* more styling fixes:

* iterate on review:

* more changes

* add env processor

* style

* more changes

* add docs

* fix imports

* fix test, add to train

* Update src/lerobot/envs/factory.py

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* iterate on review

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: jade.choghari@huggingface.co <“chogharijade@gmail.com”>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-19 18:36:14 +01:00
Eugene Mironov 8a915c6b6f [RTC] Real Time Chunking for Pi0, Smolvla, Pi0.5 (#1698)
* Add Real-Time Chunking (RTC) support for flow matching models

Implement Real-Time Chunking (RTC) for action chunking policies using flow
matching denoising. RTC enables smooth action transitions between consecutive
chunks by using prefix guidance during denoising.

Key features:
- RTCProcessor class with denoise_step method for RTC guidance
- Tracker system for debug tracking using time-based dictionary storage
- RTCDebugVisualizer with comprehensive visualization utilities
- Integration with SmolVLA policy for flow matching models
- Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP)
- Configurable execution horizon and max guidance weight
- Example scripts for dataset evaluation and real-time control

Technical details:
- Uses autograd-based gradient computation for RTC corrections
- Time-based tracking eliminates duplicate step issues
- Proxy methods in RTCProcessor for cleaner API
- Full integration with LeRobot's policy and dataset systems

Files added/modified:
- src/lerobot/configs/types.py: Add RTCAttentionSchedule enum
- src/lerobot/policies/rtc/: Core RTC implementation
  - configuration_rtc.py: RTC configuration
  - modeling_rtc.py: RTCProcessor with denoise_step
  - debug_handler.py: Tracker for debug information
  - debug_visualizer.py: Visualization utilities
- src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration
- examples/rtc/: Example scripts and evaluation tools

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Fix rtc_config attribute access in SmolVLA

Use getattr() to safely check for rtc_config attribute existence
instead of direct attribute access. This fixes AttributeError when
loading policies without rtc_config in their config.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix rtc_config attribute access in SmolVLA

* Add RTCConfig field to SmolVLAConfig

Add rtc_config as an optional field in SmolVLAConfig to properly
support Real-Time Chunking configuration. This replaces the previous
getattr() workarounds with direct attribute access, making the code
cleaner and more maintainable.

Changes:
- Import RTCConfig in configuration_smolvla.py
- Add rtc_config: RTCConfig | None = None field
- Revert getattr() calls to direct attribute access in modeling_smolvla.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Refactor RTC enabled checks to use _rtc_enabled helper

Add _rtc_enabled() helper method in VLAFlowMatching class to simplify
and clean up RTC enabled checks throughout the code. This reduces
code duplication and improves readability.

Changes:
- Add _rtc_enabled() method in VLAFlowMatching
- Replace verbose rtc_config checks with _rtc_enabled() calls
- Maintain exact same functionality with cleaner code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Rename track_debug method to track

Simplify the method name from track_debug to just track for better
readability and consistency. The method already has clear documentation
about its debug tracking purpose.

Changes:
- Rename RTCProcessor.track_debug() to track()
- Update all call sites in modeling_smolvla.py and modeling_rtc.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Use output_dir for saving all evaluation images

Update eval_dataset.py to save all comparison images to the
configured output_dir instead of the current directory. This provides
better organization and allows users to specify where outputs should be
saved.

Changes:
- Add os import at top level
- Create output_dir at start of run_evaluation()
- Save all comparison images to output_dir
- Remove duplicate os imports
- Update init_rtc_processor() docstring to be more concise

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Use output_dir for saving all evaluation images

* Fix logging buffering and enable tracking when RTC config provided

- Add force=True to logging.basicConfig to override existing configuration
- Enable line buffering for stdout/stderr for real-time log output
- Modify init_rtc_processor to create processor when rtc_config exists
  even if RTC is disabled, allowing tracking of denoising data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Refactor SmolVLA plotting to use tracker data instead of local variables

Remove local tracking variables (correction, x1_t, error) from the
denoising loop and instead retrieve plotting data from the RTC tracker
after each denoise step. This makes the code cleaner and uses the
tracker as the single source of truth for debug/visualization data.

Changes:
- Remove initialization of correction, x1_t, error before denoising loop
- After each Euler step, retrieve most recent debug step from tracker
- Extract correction, x1_t, err from debug step for plotting
- Update tracking condition to use is_debug_enabled() method

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Move plotting logic from modeling_smolvla to eval_dataset script

Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Refactor plotting loging

* fixup! Refactor plotting loging

* Improve visualization: separate correction plot and fix axis scaling

Changes:
- Create separate figure for correction data instead of overlaying on v_t
- Add _rescale_axes helper method to properly scale all axes
- Add 10% margin to y-axis for better visualization
- Fix v_t chart vertical compression issue

Benefits:
- Clearer v_t plot without correction overlay
- Better axis scaling with proper margins
- Separate correction figure for focused analysis
- Improved readability of all denoising visualizations

Output files:
- denoising_xt_comparison.png (x_t trajectories)
- denoising_vt_comparison.png (v_t velocity - now cleaner)
- denoising_correction_comparison.png (NEW - separate corrections)
- denoising_x1t_comparison.png (x1_t state with error)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* fixup! Improve visualization: separate correction plot and fix axis scaling

* fixup! fixup! Improve visualization: separate correction plot and fix axis scaling

* fixup! fixup! fixup! Improve visualization: separate correction plot and fix axis scaling

* Fix traacking

* Right kwargs for the policy

* Add tests for tracker

* Fix tests

* Drop not required methods

* Add torch compilation for eval_dataset

* delete policies

* Add matplotliv to dev

* fixup! Add matplotliv to dev

* Experiemnt with late detach

* Debug

* Fix compilation

* Add RTC to PI0

* Pi0

* Pi0 eval dataset

* fixup! Pi0 eval dataset

* Turn off compilation for pi0/pi05

* fixup! Turn off compilation for pi0/pi05

* fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05

* Add workable flow

* Small fixes

* Add more tests

* Add validatio at the end

* Update README

* Silent validation

* Fix tests

* Add tests for modeling_rtc

* Add tests for flow matching models with RTC

* fixup! Add tests for flow matching models with RTC

* fixup! fixup! Add tests for flow matching models with RTC

* Add one more test

* fixup! Add one more test

* Fix test to use _rtc_enabled() instead of is_rtc_enabled()

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled()

* fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled()

* Add RTC initialization tests without config for PI0.5 and SmolVLA

Add test_pi05_rtc_initialization_without_rtc_config and
test_smolvla_rtc_initialization_without_rtc_config to verify that
policies can initialize without RTC config and that _rtc_enabled()
returns False in this case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix PI0.5 init_rtc_processor to use getattr instead of direct model access

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix SmolVLA init_rtc_processor to use getattr instead of direct model access

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

* Fixup eval with real robot

* fixup! Fixup eval with real robot

* fixup! fixup! Fixup eval with real robot

* Extract simulator logic from eval_with real robot and add proper headers to files

* Update images

* Fix tests

* fixup! Fix tests

* add docs for rtc

* enhance doc and add images

* Fix instal instructions

---------
Co-authored-by: Ben Zhang <benzhangniu@gmail.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-19 11:19:48 +01:00
Michel Aractingi b464d9f8bc Fix episode filtering bug when requesting a subset of the episodes in a dataset (#2456)
* filter episodes in load_nested_dataset

* nit

* remove test filtering

* move import to module level

* added missing episode indices to the EpisodeAwareSampler in lerobot_train.py;
2025-11-18 17:26:41 +01:00
Pepijn 9bd69bb236 Add script to generate embedding for dataset (#2138)
* Add generate and validate script

* fix precommit

* Improve generate embeddings function by using dataset tools (#2206)

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-18 17:13:55 +01:00
Pepijn 52b080fd8c fix rewind discrepancies 2025-11-18 16:09:16 +01:00
Pepijn 0d84f4724d fix spawn 2025-11-18 15:44:24 +01:00
Pepijn 1ffdc6f49e subtasks 2025-11-18 15:28:40 +01:00
Pepijn Kooijmans f688eb160b Merge branch 'feat/add_rewind' of https://github.com/huggingface/lerobot into feat/add_rewind 2025-11-18 15:00:30 +01:00
Pepijn Kooijmans 69868360c7 add sarm 2025-11-18 15:00:05 +01:00
Pepijn 3c9149e909 small fix 2025-11-18 13:47:05 +01:00
Pepijn cf0f878dbb add annotation 2025-11-18 13:34:21 +01:00
Michel Aractingi 784cdae55a Fixes in port droid scripts (#2455)
* Fixes in port droid scripts

* revert default mem-per-cpu

* style nit

* fix relative imports

* style nit
2025-11-17 23:42:30 +01:00
Steven Palma d9e74a9d37 chore(dependencies): Bump lerobot to 0.4.2 (#2423) 2025-11-12 13:13:57 +01:00
Steven Palma a5b29d4301 chore(installation): remove libero installation patch (#2416)
* chore(installation): remove libero installation patch

* fix(ci): exclude groot for unbound deps test
2025-11-10 11:51:52 +01:00
Steven Palma a4aa316470 fix(dataset): fix data access bottleneck for faster training (#2408) 2025-11-07 21:54:44 +01:00
Michel Aractingi f6b16f6d97 fix(dataset_tools) Critical bug in modify features (#2342)
* fix bug in `_copy_data_with_feature_changes`

* Update src/lerobot/datasets/dataset_tools.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>

* add missing import

---------

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2025-11-04 15:56:41 +01:00
Jade Choghari df0c335a5a feat(sim): EnvHub - allow loading envs from the hub (#2121)
* add env from the hub support

* add safe loading

* changes

* add tests, docs

* more

* style/cleaning

* order

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-04 14:52:46 +01:00
Jade Choghari 87ed3a2b6e dep(upgrade): add libero as a pypi package (#2365)
* add changes

* Update pyproject.toml

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* add openpi-transformers

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* new changes

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* Update hf-libero version in pyproject.toml

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-04 10:43:52 +01:00
Jade Choghari d57d1aa197 fix(make_policy): rename mapping edge cases in training (#2332)
* fix bug

* update fixes

* add hf license

* more fixes

* add transformers

* iterate on review

* more fixes

* more fixes

* add a False test

* reduce img size

* reduce img size

* skip the test

* add

* add style
2025-10-31 13:08:42 +01:00
Pepijn 1da9eee095 make rewind pretrained policy 2025-10-28 10:29:35 +01:00
Caroline Pascal 3f8c5d9809 fix(video_key typo): fixing video_key typo in update_video_info (#2323) 2025-10-28 09:41:33 +01:00
Steven Palma d1548e1d13 docs(install): imrpove groot and libero installation instructions (#2314) 2025-10-26 15:37:41 +08:00
Pepijn d9f0c8c3ae add initial modeling 2025-10-15 12:52:33 +02:00
118 changed files with 14300 additions and 12486 deletions
+3 -3
View File
@@ -83,11 +83,11 @@ jobs:
fi
- name: Remove Tags with Git dependencies
# TODO(Steven): Temporary patch to remove libero and pi from PyPi 0.4.0 release due to its reliance on git dependencies.
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
run: |
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
grep -E '@ git\+https|lerobot\[pi\]|lerobot\[libero\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
sed -E -i '/@ git\+https|lerobot\[pi\]|lerobot\[libero\]/d' pyproject.toml
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
echo "::info:: Git dependencies removed. Proceeding with build."
- name: Install build dependencies
+1 -1
View File
@@ -70,7 +70,7 @@ jobs:
echo "Dependencies unbound:" && cat pyproject.toml
- name: Install lerobot with all extras
run: uv sync --all-extras
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
- name: Run pytest (all extras)
run: uv run pytest tests -vv
+1 -1
View File
@@ -186,7 +186,7 @@ For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install libero or pi tags, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`.
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
+10 -2
View File
@@ -15,8 +15,6 @@
title: Train a Robot with RL
- local: hilserl_sim
title: Train RL in Simulation
- local: async
title: Use Async Inference
- local: multi_gpu_training
title: Multi GPU training
title: "Tutorials"
@@ -41,6 +39,14 @@
title: NVIDIA GR00T N1.5
title: "Policies"
- sections:
- local: async
title: Use Async Inference
- local: rtc
title: Real-Time Chunking (RTC)
title: "Inference"
- sections:
- local: envhub
title: Environments from the Hub
- local: il_sim
title: Imitation Learning in Sim
- local: libero
@@ -57,6 +63,8 @@
title: Implement your own processor
- local: processors_robots_teleop
title: Processors for Robots and Teleoperators
- local: env_processor
title: Environment Processors
title: "Robot Processors"
- sections:
- local: so101
+418
View File
@@ -0,0 +1,418 @@
# Environment Processors
Environment processors are a critical layer in LeRobot's data processing architecture that handle **environment-specific** transformations, separate from policy-specific processing. This separation of concerns enables cleaner code, better modularity, and easier experimentation with different environments and policies.
## Why Environment Processors?
When working with different robot environments (LIBERO, MetaWorld, Aloha, etc.), each environment often has unique data formats, coordinate systems, and conventions that need standardization **before** policy processing. Without environment processors, these transformations would be:
1. **Hardcoded in environment code** - Making it difficult to experiment with different state representations
2. **Duplicated across policies** - Each policy would need to handle environment-specific quirks
3. **Mixed with policy logic** - Violating separation of concerns and making debugging harder
Environment processors solve this by providing a **dedicated processing layer** between raw environment observations and policy inputs.
## The Processing Pipeline
Here's how data flows through the complete processing pipeline during evaluation:
```python
# In lerobot_eval.py rollout() function:
# 1. Raw environment observation (numpy arrays, various formats)
raw_observation = env.step(action)
# 2. Convert numpy to torch, normalize images [0,1]
observation = preprocess_observation(raw_observation)
# 3. Add task metadata (for multi-task environments)
observation = add_envs_task(env, observation)
# 4. ENVIRONMENT-SPECIFIC preprocessing (NEW!)
# - Flatten robot states
# - Rotate images to match dataset conventions
# - Handle environment-specific coordinate systems
observation = env_preprocessor(observation)
# 5. POLICY-SPECIFIC preprocessing
# - Normalize with dataset statistics
# - Add batch dimensions
# - Move to GPU
# - Tokenize language instructions
observation = preprocessor(observation)
# 6. Policy inference
action = policy.select_action(observation)
# 7. POLICY-SPECIFIC postprocessing
# - Unnormalize actions
# - Remove batch dimensions
action = postprocessor(action)
# 8. ENVIRONMENT-SPECIFIC postprocessing (NEW!)
# - Convert action formats if needed
# - Apply environment-specific constraints
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
# 9. Execute in environment
env.step(action)
```
## The Benefits
### 1. **Separation of Concerns**
Environment processors handle transformations specific to the **environment's data format**, while policy processors handle transformations specific to the **model's requirements**.
```python
# ❌ Before: Mixed concerns
class LiberoVLAPolicy:
def preprocess(self, obs):
# Environment-specific: Flatten robot state (shouldn't be in policy!)
state = self._flatten_robot_state(obs["robot_state"])
# Policy-specific: Normalize with dataset stats
state = self.normalizer(state)
return state
# ✅ After: Clear separation
# Environment processor: Handles LIBERO's nested robot state
env_preprocessor = LiberoProcessorStep() # Flattens robot_state
# Policy processor: Handles model requirements
policy_preprocessor = NormalizerProcessorStep(stats=dataset_stats)
```
### 2. **Flexibility and Reusability**
The same policy can work with different environment processors, and the same environment processor can work with different policies:
```python
# Use SmolVLA policy with LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
# Or use ACT policy with the same LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
```
### 3. **Easier Experimentation**
Want to try different state representations for LIBERO? Just create a new processor:
```python
# Original: 8D state (pos + quat→axisangle + gripper)
@ProcessorStepRegistry.register("libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
def _process_observation(self, obs):
eef_pos = robot_state["eef"]["pos"] # 3D
eef_axisangle = quat2axisangle(quat) # 3D
gripper = robot_state["gripper"]["qpos"] # 2D
state = torch.cat([eef_pos, eef_axisangle, gripper], dim=-1) # 8D
return state
# Experiment: Add velocity for better control
@ProcessorStepRegistry.register("libero_velocity_processor")
class LiberoVelocityProcessorStep(ObservationProcessorStep):
def _process_observation(self, obs):
# Include velocities for 14D state
eef_pos = robot_state["eef"]["pos"] # 3D
eef_axisangle = quat2axisangle(quat) # 3D
eef_vel = robot_state["eef"]["vel"] # 3D (NEW)
gripper_pos = robot_state["gripper"]["qpos"] # 2D
gripper_vel = robot_state["gripper"]["qvel"] # 3D (NEW)
state = torch.cat([eef_pos, eef_axisangle, eef_vel,
gripper_pos, gripper_vel], dim=-1) # 14D
return state
```
### 4. **Cleaner Environment Code**
Environments expose **all available data** without needing to know what downstream models will use:
```python
# LIBERO environment exposes full robot state
observation = {
"pixels": {"image": img, "image2": img2},
"robot_state": {
"eef": {"pos": ..., "quat": ..., "vel": ..., "mat": ..., "axisangle": ...},
"gripper": {"qpos": ..., "qvel": ...},
"joints": {"pos": ..., "vel": ...}
}
}
# Environment processor decides what to use
# Policy processor handles model-specific transformations
```
## Using Environment Processors
### Factory Function
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
```python
from lerobot.envs.factory import make_env_pre_post_processors
from lerobot.envs.configs import LiberoEnv, PushtEnv
# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
# For other environments: Returns identity processors (no-op)
pusht_cfg = PushtEnv()
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
```
### Implementation in `envs/factory.py`
```python
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs
"""
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
else:
# For all other environments, return an identity preprocessor
preprocessor = PolicyProcessorPipeline(steps=[])
# Postprocessor is currently identity for all environments
# Future: Could add environment-specific action transformations
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
```
### Integration in Evaluation
In `lerobot_eval.py`, the environment processors are created once and used throughout:
```python
def eval_main(cfg: EvalPipelineConfig):
# Create environment
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size)
# Create policy
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
# Create policy processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
)
# Create environment processors (NEW!)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
# Run evaluation with both processor types
eval_policy_all(
envs=envs,
policy=policy,
env_preprocessor=env_preprocessor, # Environment-specific
env_postprocessor=env_postprocessor, # Environment-specific
preprocessor=preprocessor, # Policy-specific
postprocessor=postprocessor, # Policy-specific
n_episodes=cfg.eval.n_episodes,
)
```
## Example: LIBERO Environment Processor
The `LiberoProcessorStep` demonstrates a real-world environment processor:
```python
from lerobot.processor.pipeline import ObservationProcessorStep
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
**State Processing:**
- Extracts end-effector position (3D)
- Converts quaternion to axis-angle representation (3D)
- Extracts gripper joint positions (2D)
- Concatenates into 8D state vector
**Image Processing:**
- Rotates images 180° to match HuggingFaceVLA/libero convention
"""
def _process_observation(self, observation):
processed_obs = observation.copy()
# Process images: Flip 180° for camera convention
for key in list(processed_obs.keys()):
if key.startswith("observation.images."):
img = processed_obs[key]
img = torch.flip(img, dims=[2, 3]) # Flip H and W
processed_obs[key] = img
# Process robot_state: Flatten to 8D vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
eef_pos = robot_state["eef"]["pos"] # (B, 3)
eef_quat = robot_state["eef"]["quat"] # (B, 4)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
state = state.float()
processed_obs["observation.state"] = state
return processed_obs
```
### Why These Transformations?
1. **Image Rotation**: The HuggingFaceVLA/libero dataset has images rotated 180° from the raw LIBERO simulator. The processor handles this convention mismatch so policies trained on the dataset work seamlessly.
2. **State Flattening**: The raw LIBERO environment exposes nested dictionaries with all available state information (position, quaternion, velocity, matrix representation, etc.). The processor:
- Selects the relevant components (pos, quat, gripper)
- Converts quaternion to axis-angle (more suitable for learning)
- Flattens to a single 8D vector that policies expect
3. **Flexibility**: The environment still exposes **all** raw data. If you want to try different state representations (e.g., including velocities, using matrix representation instead of axis-angle), you can create a new processor without modifying the environment code.
## Adding Environment Processors for New Environments
To add environment processors for a new environment:
### 1. Create the Processor Step
```python
# In src/lerobot/processor/env_processor.py
@dataclass
@ProcessorStepRegistry.register(name="myenv_processor")
class MyEnvProcessorStep(ObservationProcessorStep):
"""Process observations from MyEnv."""
def _process_observation(self, observation):
processed = observation.copy()
# Your environment-specific transformations
if "myenv.specific.state" in processed:
state = processed.pop("myenv.specific.state")
# Transform to standard format
processed["observation.state"] = self._transform_state(state)
return processed
```
### 2. Update the Factory
```python
# In src/lerobot/envs/factory.py
def make_env_pre_post_processors(env_cfg: EnvConfig):
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
else:
preprocessor = PolicyProcessorPipeline(steps=[])
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
```
### 3. Use in Evaluation
No changes needed! The evaluation script automatically uses the appropriate processor:
```bash
lerobot-eval \
--policy.path=lerobot/my_policy \
--env.type=myenv \ # Automatically uses MyEnvProcessorStep
--eval.n_episodes=10
```
## Future: Environment Postprocessors
Currently, postprocessors are identity (no-op) for all environments. Future use cases include:
### Action Space Transformations
```python
@dataclass
class MyEnvActionPostprocessor(ProcessorStep):
"""Convert policy actions to environment-specific format."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition["action"]
# Example: Convert from Cartesian to joint space
if self.action_space == "joint":
action = self.ik_solver(action)
# Example: Apply environment-specific safety limits
action = torch.clamp(action, self.min_action, self.max_action)
transition["action"] = action
return transition
```
### Coordinate System Conversions
```python
@dataclass
class CoordinateTransformPostprocessor(ProcessorStep):
"""Transform actions between coordinate systems."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition["action"]
# Example: Policy outputs in world frame, env expects base frame
action = self.world_to_base_transform(action)
transition["action"] = action
return transition
```
## Best Practices
1. **Keep environment processors simple**: They should only handle environment-specific data format issues, not complex learning-related transformations.
2. **Use policy processors for model requirements**: Normalization, batching, device placement, and tokenization belong in policy processors.
3. **Expose all data from environments**: Let processors decide what to use rather than hardcoding choices in the environment.
4. **Document conventions**: Clearly document any coordinate system conventions, camera orientations, or data formats that your processor handles.
5. **Test independently**: Environment processors should be testable without loading full policies or environments.
## Summary
Environment processors provide a **clean separation** between environment-specific data transformations and policy-specific model requirements. This architecture:
- ✅ Enables easy experimentation with different state representations
- ✅ Allows policies to work seamlessly across different environments
- ✅ Keeps environment code focused on simulation/hardware interface
- ✅ Makes processor pipelines more maintainable and debuggable
- ✅ Follows the single responsibility principle
The key insight: **Environments define data formats, processors standardize them, policies consume standardized data.** Each layer has a clear, focused responsibility.
+424
View File
@@ -0,0 +1,424 @@
# Loading Environments from the Hub
The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a powerful new model for collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community.
## Overview
With EnvHub, you can:
- Load environments from the Hub instantly
- Share your custom simulation tasks with the community
- Version control your environments using Git
- Distribute complex physics simulations without packaging hassles
## Quick Start
Loading an environment from the Hub is as simple as:
```python
from lerobot.envs.factory import make_env
# Load a hub environment (requires explicit consent to run remote code)
env = make_env("lerobot/cartpole-env", trust_remote_code=True)
```
<Tip warning={true}>
**Security Notice**: Loading environments from the Hub executes Python code
from third-party repositories. Only use `trust_remote_code=True` with
repositories you trust. We strongly recommend pinning to a specific commit
hash for reproducibility and security.
</Tip>
## What is EnvHub?
EnvHub is a framework that allows researchers and developers to:
1. **Publish environments** to the Hugging Face Hub as Git repositories
2. **Load environments** dynamically without installing them as packages
3. **Version and track** environment changes using Git semantics
4. **Discover** new simulation tasks shared by the community
This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, without worrying about dependency conflicts or complex installation procedures.
## Repository Structure
To make your environment loadable from the Hub, your repository must contain at minimum:
### Required Files
**`env.py`** (or custom Python file)
- Must expose a `make_env(n_envs: int, use_async_envs: bool)` function
- This function should return one of:
- A `gym.vector.VectorEnv` (most common)
- A single `gym.Env` (will be automatically wrapped)
- A dict mapping `{suite_name: {task_id: VectorEnv}}` (for multi-task benchmarks)
### Optional Files
**`requirements.txt`**
- List any additional dependencies your environment needs
- Users will need to install these manually before loading your environment
**`README.md`**
- Document your environment: what task it implements, observation/action spaces, rewards, etc.
- Include usage examples and any special setup instructions
**`.gitignore`**
- Exclude unnecessary files from your repository
### Example Repository Structure
```
my-environment-repo/
├── env.py # Main environment definition (required)
├── requirements.txt # Dependencies (optional)
├── README.md # Documentation (recommended)
├── assets/ # Images, videos, etc. (optional)
│ └── demo.gif
└── configs/ # Config files if needed (optional)
└── task_config.yaml
```
## Creating Your Environment Repository
### Step 1: Define Your Environment
Create an `env.py` file with a `make_env` function:
```python
# env.py
import gymnasium as gym
def make_env(n_envs: int = 1, use_async_envs: bool = False):
"""
Create vectorized environments for your custom task.
Args:
n_envs: Number of parallel environments
use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv
Returns:
gym.vector.VectorEnv or dict mapping suite names to vectorized envs
"""
def _make_single_env():
# Create your custom environment
return gym.make("CartPole-v1")
# Choose vector environment type
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
# Create vectorized environment
vec_env = env_cls([_make_single_env for _ in range(n_envs)])
return vec_env
```
### Step 2: Test Locally
Before uploading, test your environment locally:
```python
from lerobot.envs.utils import _load_module_from_path, _call_make_env, _normalize_hub_result
# Load your module
module = _load_module_from_path("./env.py")
# Test the make_env function
result = _call_make_env(module, n_envs=2, use_async_envs=False)
normalized = _normalize_hub_result(result)
# Verify it works
suite_name = next(iter(normalized))
env = normalized[suite_name][0]
obs, info = env.reset()
print(f"Observation shape: {obs.shape if hasattr(obs, 'shape') else type(obs)}")
env.close()
```
### Step 3: Upload to the Hub
Upload your repository to Hugging Face:
```bash
# Install huggingface_hub if needed
pip install huggingface_hub
# Login to Hugging Face
huggingface-cli login
# Create a new repository
huggingface-cli repo create my-custom-env --type space --org my-org
# Initialize git and push
git init
git add .
git commit -m "Initial environment implementation"
git remote add origin https://huggingface.co/my-org/my-custom-env
git push -u origin main
```
Alternatively, use the `huggingface_hub` Python API:
```python
from huggingface_hub import HfApi
api = HfApi()
# Create repository
api.create_repo("my-custom-env", repo_type="space")
# Upload files
api.upload_folder(
folder_path="./my-env-folder",
repo_id="username/my-custom-env",
repo_type="space",
)
```
## Loading Environments from the Hub
### Basic Usage
```python
from lerobot.envs.factory import make_env
# Load from the hub
envs_dict = make_env(
"username/my-custom-env",
n_envs=4,
trust_remote_code=True
)
# Access the environment
suite_name = next(iter(envs_dict))
env = envs_dict[suite_name][0]
# Use it like any gym environment
obs, info = env.reset()
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
```
### Advanced: Pinning to Specific Versions
For reproducibility and security, pin to a specific Git revision:
```python
# Pin to a specific branch
env = make_env("username/my-env@main", trust_remote_code=True)
# Pin to a specific commit (recommended for papers/experiments)
env = make_env("username/my-env@abc123def456", trust_remote_code=True)
# Pin to a tag
env = make_env("username/my-env@v1.0.0", trust_remote_code=True)
```
### Custom File Paths
If your environment definition is not in `env.py`:
```python
# Load from a custom file
env = make_env("username/my-env:custom_env.py", trust_remote_code=True)
# Combine with version pinning
env = make_env("username/my-env@v1.0:envs/task_a.py", trust_remote_code=True)
```
### Async Environments
For better performance with multiple environments:
```python
envs_dict = make_env(
"username/my-env",
n_envs=8,
use_async_envs=True, # Use AsyncVectorEnv for parallel execution
trust_remote_code=True
)
```
## URL Format Reference
The hub URL format supports several patterns:
| Pattern | Description | Example |
| -------------------- | ------------------------------ | -------------------------------------- |
| `user/repo` | Load `env.py` from main branch | `make_env("lerobot/pusht-env")` |
| `user/repo@revision` | Load from specific revision | `make_env("lerobot/pusht-env@main")` |
| `user/repo:path` | Load custom file | `make_env("lerobot/envs:pusht.py")` |
| `user/repo@rev:path` | Revision + custom file | `make_env("lerobot/envs@v1:pusht.py")` |
## Multi-Task Environments
For benchmarks with multiple tasks (like LIBERO), return a nested dictionary:
```python
def make_env(n_envs: int = 1, use_async_envs: bool = False):
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
# Return dict: {suite_name: {task_id: VectorEnv}}
return {
"suite_1": {
0: env_cls([lambda: gym.make("Task1-v0") for _ in range(n_envs)]),
1: env_cls([lambda: gym.make("Task2-v0") for _ in range(n_envs)]),
},
"suite_2": {
0: env_cls([lambda: gym.make("Task3-v0") for _ in range(n_envs)]),
}
}
```
## Security Considerations
<Tip warning={true}>
**Important**: The `trust_remote_code=True` flag is required to execute
environment code from the Hub. This is by design for security.
</Tip>
When loading environments from the Hub:
1. **Review the code first**: Visit the repository and inspect `env.py` before loading
2. **Pin to commits**: Use specific commit hashes for reproducibility
3. **Check dependencies**: Review `requirements.txt` for suspicious packages
4. **Use trusted sources**: Prefer official organizations or well-known researchers
5. **Sandbox if needed**: Run untrusted code in isolated environments (containers, VMs)
Example of safe usage:
```python
# ❌ BAD: Loading without inspection
env = make_env("random-user/untrusted-env", trust_remote_code=True)
# ✅ GOOD: Review code, then pin to specific commit
# 1. Visit https://huggingface.co/trusted-org/verified-env
# 2. Review the env.py file
# 3. Copy the commit hash
env = make_env("trusted-org/verified-env@a1b2c3d4", trust_remote_code=True)
```
## Example: CartPole from the Hub
Here's a complete example using the reference CartPole environment:
```python
from lerobot.envs.factory import make_env
import numpy as np
# Load the environment
envs_dict = make_env("lerobot/cartpole-env", n_envs=4, trust_remote_code=True)
# Get the vectorized environment
suite_name = next(iter(envs_dict))
env = envs_dict[suite_name][0]
# Run a simple episode
obs, info = env.reset()
done = np.zeros(env.num_envs, dtype=bool)
total_reward = np.zeros(env.num_envs)
while not done.all():
# Random policy
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
total_reward += reward
done = terminated | truncated
print(f"Average reward: {total_reward.mean():.2f}")
env.close()
```
## Benefits of EnvHub
### For Environment Authors
- **Easy distribution**: No PyPI packaging required
- **Version control**: Use Git for environment versioning
- **Rapid iteration**: Push updates instantly
- **Documentation**: Hub README renders beautifully
- **Community**: Reach LeRobot users directly
### For Researchers
- **Quick experiments**: Load any environment in one line
- **Reproducibility**: Pin to specific commits
- **Discovery**: Browse environments on the Hub
- **No conflicts**: No need to install conflicting packages
### For the Community
- **Growing ecosystem**: More diverse simulation tasks
- **Standardization**: Common `make_env` API
- **Collaboration**: Fork and improve existing environments
- **Accessibility**: Lower barrier to sharing research
## Troubleshooting
### "Refusing to execute remote code"
You must explicitly pass `trust_remote_code=True`:
```python
env = make_env("user/repo", trust_remote_code=True)
```
### "Module X not found"
The hub environment has dependencies you need to install:
```bash
# Check the repo's requirements.txt and install dependencies
pip install gymnasium numpy
```
### "make_env not found in module"
Your `env.py` must expose a `make_env` function:
```python
def make_env(n_envs: int, use_async_envs: bool):
# Your implementation
pass
```
### Environment returns wrong type
The `make_env` function must return:
- A `gym.vector.VectorEnv`, or
- A single `gym.Env`, or
- A dict `{suite_name: {task_id: VectorEnv}}`
## Best Practices
1. **Document your environment**: Include observation/action space descriptions, reward structure, and termination conditions in your README
2. **Add requirements.txt**: List all dependencies with versions
3. **Test thoroughly**: Verify your environment works locally before pushing
4. **Use semantic versioning**: Tag releases with version numbers
5. **Add examples**: Include usage examples in your README
6. **Keep it simple**: Minimize dependencies when possible
7. **License your work**: Add a LICENSE file to clarify usage terms
## Future Directions
The EnvHub ecosystem enables exciting possibilities:
- **GPU-accelerated physics**: Share Isaac Gym or Brax environments
- **Photorealistic rendering**: Distribute environments with advanced graphics
- **Multi-agent scenarios**: Complex interaction tasks
- **Real-world simulators**: Digital twins of physical setups
- **Procedural generation**: Infinite task variations
- **Domain randomization**: Pre-configured DR pipelines
As more researchers and developers contribute, the diversity and quality of available environments will grow, benefiting the entire robotics learning community.
## See Also
- [Hugging Face Hub Documentation](https://huggingface.co/docs/hub/en/index)
- [Gymnasium Documentation](https://gymnasium.farama.org/index.html)
- [Example Hub Environment](https://huggingface.co/lerobot/cartpole-env)
+4 -1
View File
@@ -40,7 +40,7 @@ python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} i
3. Install LeRobot by running:
```bash
pip install lerobot[groot] # consider also installing libero,dev and test tags
pip install lerobot[groot]
```
## Usage
@@ -83,6 +83,9 @@ accelerate launch \
### Libero Benchmark Results
> [!NOTE]
> Follow our instructions for Libero usage: [Libero](./libero)
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
| Benchmark | LeRobot Implementation | GR00T Reference |
+1 -1
View File
@@ -82,7 +82,7 @@ For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install libero or pi, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
### Troubleshooting
-328
View File
@@ -1,328 +0,0 @@
# OpenArms Robot
OpenArms is a 7 DOF robotic arm with a gripper, designed by [Enactic, Inc.](https://www.enactic.com/) It uses Damiao motors controlled via CAN bus communication and MIT control mode for smooth, precise motion.
## Hardware Overview
- **7 DOF per arm** (14 DOF total for dual arm setup)
- **1 gripper per arm** (2 grippers total)
- **Damiao motors** with 4 different types:
- **DM8009** (DM-J8009P-2EC) for shoulders (J1, J2) - high torque
- **DM4340** for shoulder rotation and elbow (J3, J4)
- **DM4310** (DM-J4310-2EC V1.1) for wrist (J5, J6, J7) and gripper (J8)
- **24V power supply** required
- **CAN interface device**:
- **Linux**: Any SocketCAN-compatible adapter
- **macOS**: CANable, PEAK PCAN-USB, or Kvaser USBcan
- Proper CAN wiring (CANH, CANL, 120Ω termination)
## Motor Configuration
Each arm has the following motor configuration based on the [OpenArm setup guide](https://docs.openarm.dev/software/setup/):
| Joint | Motor | Motor Type | Sender CAN ID | Receiver ID | Description |
|-------|-------|------------|---------------|-------------|-------------|
| J1 | joint_1 | DM8009 | 0x01 | 0x11 | Shoulder pan |
| J2 | joint_2 | DM8009 | 0x02 | 0x12 | Shoulder lift |
| J3 | joint_3 | DM4340 | 0x03 | 0x13 | Shoulder rotation |
| J4 | joint_4 | DM4340 | 0x04 | 0x14 | Elbow flex |
| J5 | joint_5 | DM4310 | 0x05 | 0x15 | Wrist roll |
| J6 | joint_6 | DM4310 | 0x06 | 0x16 | Wrist pitch |
| J7 | joint_7 | DM4310 | 0x07 | 0x17 | Wrist rotation |
| J8 | gripper | DM4310 | 0x08 | 0x18 | Gripper |
For dual arm setups, the left arm uses IDs 0x09-0x10 for joints 1-8 with the same motor types.
## Quick Start
```bash
# Install system dependencies
sudo apt install can-utils iproute2
# Install LeRobot with OpenArms support
pip install -e ".[openarms]"
```
## Setup Guide
### Step 1: Motor ID Configuration
**IMPORTANT**: Before using the robot, motors must be configured with the correct CAN IDs.
Refer to the [OpenArm Motor ID Configuration Guide](https://docs.openarm.dev/software/setup/motor-id) for detailed instructions using the Damiao Debugging Tools on Windows.
Key points:
- Each motor needs a unique **Sender CAN ID** (0x01-0x08)
- Each motor needs a unique **Receiver/Master ID** (0x11-0x18)
- Use the Damiao Debugging Tools to set these IDs
### Step 2: Setup CAN Interface
Configure your CAN interface as described in the [OpenArm CAN Setup Guide](https://docs.openarm.dev/software/setup/can-setup):
#### Linux (SocketCAN)
```bash
# Find your CAN interface
ip link show
# Configure can0, 1, 2, 3
sudo ip link set can0 down
sudo ip link set can0 type can bitrate 1000000
sudo ip link set can0 up
sudo ip link set can1 down
sudo ip link set can1 type can bitrate 1000000
sudo ip link set can1 up
sudo ip link set can2 down
sudo ip link set can2 type can bitrate 1000000
sudo ip link set can2 up
sudo ip link set can3 down
sudo ip link set can3 type can bitrate 1000000
sudo ip link set can3 up
# Verify configuration
ip link show can0
```
or run:
`examples/openarms/setup_can.sh`
### Testing canbus and motor connection
Please run this script to check if all motors can be found and to find your can-fd speed: `python examples/openarms/debug_can_communication.py`
## Usage
### Basic Setup
```python
from lerobot.robots.openarms import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
# Configure for dual arm setup
config = OpenArmsFollowerConfig(
port="can0",
can_interface="socketcan", # Or "auto" for auto-detection
id="openarms_dual",
is_dual_arm=True,
)
robot = OpenArmsFollower(config)
robot.connect()
```
### Calibration
On first use, you'll need to calibrate the robot:
```python
robot.calibrate()
```
The calibration process will:
1. Disable torque on all motors
2. Ask you to position arms in **hanging position with grippers closed**
3. Set this as the zero position
4. Ask you to move each joint through its full range
5. Record min/max positions for each joint
6. Save calibration to file
### Reading Observations
The robot provides comprehensive state information:
```python
observation = robot.get_observation()
# Observation includes for each motor:
# - {motor_name}.pos: Position in degrees
# - {motor_name}.vel: Velocity in degrees/second
# - {motor_name}.torque: Motor torque
# - {camera_name}: Camera images (if configured)
print(f"Right arm joint 1 position: {observation['right_joint_1.pos']:.1f}°")
print(f"Right arm joint 1 velocity: {observation['right_joint_1.vel']:.1f}°/s")
print(f"Right arm joint 1 torque: {observation['right_joint_1.torque']:.3f} N·m")
```
### Sending Actions
```python
# Send target positions (in degrees)
action = {
"right_joint_1.pos": 45.0,
"right_joint_2.pos": -30.0,
# ... all joints
"right_gripper.pos": 45.0, # Half-closed
}
actual_action = robot.send_action(action)
```
### Gripper Control
```python
# Open gripper
robot.open_gripper(arm="right")
# Close gripper
robot.close_gripper(arm="right")
```
## Safety Features
### 1. Maximum Relative Target
Limits how far a joint can move in a single command to prevent sudden movements:
```python
config = OpenArmsFollowerConfig(
port="can0",
# Limit all joints to 10 degrees per command
max_relative_target=10.0,
# Or set per-motor limits
max_relative_target={
"right_joint_1": 15.0, # Slower moving joint
"right_joint_2": 10.0,
"right_gripper": 5.0, # Very slow gripper
}
)
```
**How it works**: If current position is 50° and you command 80°, with `max_relative_target=10.0`, the robot will only move to 60° in that step.
### 2. Torque Limits
Control maximum torque output, especially important for grippers and teleoperation:
```python
config = OpenArmsFollowerConfig(
port="can0",
# Gripper torque limit (fraction of motor's max torque)
gripper_torque_limit=0.5, # 50% of max torque
)
```
Lower torque limits prevent damage when gripping delicate objects.
### 3. MIT Control Gains
Control responsiveness and stability via PID-like gains:
```python
config = OpenArmsFollowerConfig(
port="can0",
position_kp=10.0, # Position gain (higher = more responsive)
position_kd=0.5, # Velocity damping (higher = more damped)
)
```
**Guidelines**:
- **For following (robot)**: Higher gains for responsiveness
- `position_kp=10.0`, `position_kd=0.5`
- **For teleoperation (leader)**: Lower gains or disable torque for manual movement
- `manual_control=True` (torque disabled)
### 4. Velocity Limits
Velocity limits are enforced by the Damiao motors based on motor type. For DM4310:
- Max velocity: 30 rad/s ≈ 1718°/s
The motors will automatically limit velocity to safe values.
## Teleoperation
### Leader Arm Setup
The leader arm is moved manually (torque disabled) to generate commands:
```python
from lerobot.teleoperators.openarms import OpenArmsLeader
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
config = OpenArmsLeaderConfig(
port="can1", # Separate CAN interface for leader
id="openarms_leader",
manual_control=True, # Torque disabled for manual movement
is_dual_arm=True,
)
leader = OpenArmsLeader(config)
leader.connect()
# Read current position as action
action = leader.get_action()
# action contains positions for all joints in degrees
```
### Safety Considerations for Teleoperation
1. **Use separate CAN interfaces** for leader and follower to avoid conflicts
2. **Enable max_relative_target** on follower to smooth abrupt movements
3. **Lower torque limits** on follower to prevent damage from tracking errors
4. **Test with one arm** before enabling dual arm teleoperation
5. **Have emergency stop** ready (power switch or CAN disable)
```python
# Recommended follower config for teleoperation
follower_config = OpenArmsFollowerConfig(
port="can0",
max_relative_target=5.0, # Small steps for smooth following
gripper_torque_limit=0.3, # Low torque for safety
position_kp=5.0, # Lower gains for gentler following
position_kd=0.3,
)
```
## Troubleshooting
### Motor Shaking/Unstable
- **Lower control gains**: Reduce `position_kp` and `position_kd`
- **Check calibration**: Re-run calibration procedure
- **Verify power**: Insufficient current can cause instability
- **Check mechanical**: Loose connections, binding, or damaged components
### CAN Bus Errors
```bash
# Check for errors
ip -s link show can0
# Reset CAN interface
sudo ip link set can0 down
sudo ip link set can0 up
```
### Control Mode
OpenArms uses **MIT control mode** which allows simultaneous control of:
- Position (degrees)
- Velocity (degrees/second)
- Torque (N·m)
- Position gain (Kp)
- Velocity damping (Kd)
### Communication
- **Protocol**: CAN 2.0 at 1 Mbps (or CAN-FD at 5 Mbps)
- **Frame format**: Standard 11-bit IDs
- **Update rate**: Typically 50-100 Hz depending on motor count
- **Latency**: ~10-20ms per motor command
## References
- [OpenArm Official Documentation](https://docs.openarm.dev/)
- [OpenArm Setup Guide](https://docs.openarm.dev/software/setup/)
- [Motor ID Configuration](https://docs.openarm.dev/software/setup/motor-id)
- [CAN Interface Setup](https://docs.openarm.dev/software/setup/can-setup)
- [Motor Communication Test](https://docs.openarm.dev/software/setup/configure-test)
- [Damiao Motor Documentation](https://wiki.seeedstudio.com/damiao_series/)
- [Enactic GitHub](https://github.com/enactic/openarm_can)
+5
View File
@@ -28,6 +28,11 @@ As described by Physical Intelligence, while AI has achieved remarkable success
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Training Data and Capabilities
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
+5
View File
@@ -36,6 +36,11 @@ This diverse training mixture creates a "curriculum" that enables generalization
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Usage
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
+188
View File
@@ -0,0 +1,188 @@
# Real-Time Chunking (RTC)
Real-Time Chunking (RTC) is an inference-time method that allows large, flow-matching based robotic policies, such as [Pi0](./pi0), [Pi0.5](./pi05), and [SmolVLA](./smolvla), to produce smooth, continuous, and reactive motion despite having high inference latency.
These policies generate chunks of future actions (e.g., 50 steps at a time) instead of single actions.
Because the models are large, producing each chunk takes longer than the time it takes the robot to execute it.
Naively executing chunks leads to problems such as pauses, jerky transitions, or sudden changes in strategy whenever the next chunk arrives late or disagrees with the previously executed actions.
RTC solves this by asynchronously generating the next chunk while the robot continues executing the current one, and by guiding the new chunk so it aligns smoothly with the portion of the previous chunk that has already been executed.
## How RTC Works (simplified)
RTC lets the robot think ahead while its still moving. When the robot is carrying out one chunk of actions, RTC starts creating the next chunk early.
But since the robot has already moved a bit by the time the new chunk is ready, RTC has to make sure the new chunk still lines up smoothly with what the robot is currently doing.
To do this, RTC treats the beginning of the new chunk like an inpainting or “fill-in-the-gaps” problem:
it gently adjusts the first part of the new chunk so it blends naturally with the robots ongoing motion. The result is no pauses, no sudden jumps.
In technical terms, RTC adds a guidance term to the flow-matching denoising process that forces the overlapping timesteps of the new chunk to stay close to the executed portion of the previous chunk, typically using a soft transition mask.
## Quick Start
### Installation
RTC is built into LeRobot. Just install the policy dependencies you need:
```bash
# For Pi0 or Pi0.5
pip install -e ".[pi]"
# For SmolVLA
pip install -e ".[smolvla]"
```
### Using RTC with Pi0
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python
from lerobot.policies.pi0 import PI0Policy, PI0Config
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.action_queue import ActionQueue
# Load Pi0 with RTC enabled
policy_cfg = PI0Config()
# Enable RTC
policy_cfg.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10, # How many steps to blend with previous chunk
max_guidance_weight=10.0, # How strongly to enforce consistency
prefix_attention_schedule=RTCAttentionSchedule.EXP, # Exponential blend
)
# Load the policy
policy = PI0Policy.from_pretrained("lerobot/pi0_base", policy_cfg=policy_cfg, device="cuda")
# Now use predict_action_chunk with RTC parameters
inference_delay = 4 # How many steps of inference latency, this values should be calculated based on the inference latency of the policy
# Initialize the action queue
action_queue = ActionQueue(policy_cfg.rtc_config)
# Start in a separate thread with the following function
def get_actions():
while True:
if should_get_actions:
prev_actions = action_queue.get_left_over()
obs = get_robot_observations(robot)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
action_queue.merge(
actions, actions, inference_delay
)
for step in range(num_steps):
action = action_queue.get()
# Execute the first N actions
execute_actions(action)
```
## Key Parameters
`RTCConfig` has the following parameters to tune:
**`execution_horizon`**: How many timesteps from the previous chunk to maintain consistency with. Higher values mean smoother transitions but potentially less reactivity.
Typical values: 8-12 steps
```python
RTCConfig(execution_horizon=10)
```
**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy. For 10 steps flow matching (SmolVLA, Pi0, Pi0.5), a value of 10.0 is a optimal value.
**`prefix_attention_schedule`**: How to weight consistency across the overlap region.
- `LINEAR`: Linear decay from inference_delay to execution_horizon
- `EXP`: Exponential decay (recommended for getting started)
- `ONES`: Full weight across entire execution_horizon
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
**`inference_delay`**: How many timesteps of inference latency your system has. This is passed to `predict_action_chunk()` rather than the config, since it may vary at runtime.
## Testing RTC Offline
Before running on a real robot, test RTC with dataset samples to visualize how it works:
```bash
python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--rtc.max_guidance_weight=10.0 \
--device=cuda
```
The script generates a visualization of the denoising process, comparing standard generation (left) with RTC (right). In the RTC plots, you can see how the first few steps (blue/purple lines) are guided to match the red ground truth trajectory (previous chunk's tail), ensuring a smooth transition between chunks.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/flow_matching.png"
alt="Denoising steps with and without RTC"
width="100%"
/>
</p>
## Testing RTC with a Real Robot
```bash
python examples/rtc/eval_with_real_robot.py \
--policy.path=${HF_USERNAME}/policy_repo_id \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120 \
--device=cuda
```
## How It Differs from the Async Inference in LeRobot
Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
| Aspect | Async Inference | RTC |
| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
| **Best Used** | Async inference is best used with large models with high inference latency | Flow-matching based policies |
**Use both together** for maximum smoothness and reactivity!
## Advanced: Debug Tracking
RTC includes built-in debug tracking to help you understand what's happening during inference:
```python
# Enable debug tracking
policy_cfg.rtc_config.debug = True
policy_cfg.rtc_config.debug_maxlen = 100
# After inference, access debug data
debug_data = policy.rtc_processor.get_debug_data()
# Visualize denoising steps, corrections, etc.
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
visualizer = RTCDebugVisualizer()
# ... create plots
```
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
## References
- [Smooth-As-Butter Robot Policies](https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html) - Excellent technical explanation with real robot results
- [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/research/real_time_chunking) - Original paper and research
- [Kinetix RTC Implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix) - Reference implementation from Physical Intelligence
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,525 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visualize SARM Subtask Annotations
This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
For each episode, it shows:
- A timeline with dashed vertical lines at subtask boundaries
- Sample frames from the episode at key points (start, middle, end of each subtask)
- Color-coded subtask segments
Usage:
python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
"""
import argparse
import random
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from rich.console import Console
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import load_episodes
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
def timestamp_to_seconds(timestamp: str) -> float:
"""Convert MM:SS or SS timestamp to seconds"""
parts = timestamp.split(":")
if len(parts) == 2:
return int(parts[0]) * 60 + int(parts[1])
else:
return int(parts[0])
def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
"""
Load annotations from LeRobot dataset parquet files.
Reads subtask annotations from the episodes metadata parquet files.
"""
episodes_dataset = load_episodes(dataset_path)
if episodes_dataset is None or len(episodes_dataset) == 0:
return {}
# Check if subtask columns exist
if "subtask_names" not in episodes_dataset.column_names:
return {}
# Convert to pandas DataFrame for easier access
episodes_df = episodes_dataset.to_pandas()
annotations = {}
for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, "subtask_names"]
# Skip episodes without annotations
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
continue
start_times = episodes_df.loc[ep_idx, "subtask_start_times"]
end_times = episodes_df.loc[ep_idx, "subtask_end_times"]
# Reconstruct SubtaskAnnotation from stored data
subtasks = []
for i, name in enumerate(subtask_names):
# Convert seconds back to MM:SS format
start_sec = int(start_times[i])
end_sec = int(end_times[i])
start_str = f"{start_sec // 60:02d}:{start_sec % 60:02d}"
end_str = f"{end_sec // 60:02d}:{end_sec % 60:02d}"
subtasks.append(
Subtask(
name=name,
timestamps=Timestamp(start=start_str, end=end_str)
)
)
annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)
return annotations
# Color palette for subtasks (colorblind-friendly)
SUBTASK_COLORS = [
"#E69F00", # Orange
"#56B4E9", # Sky blue
"#009E73", # Bluish green
"#F0E442", # Yellow
"#0072B2", # Blue
"#D55E00", # Vermillion
"#CC79A7", # Reddish purple
"#999999", # Gray
]
def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
"""Extract a single frame from video at given timestamp."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
return None
# Set position to timestamp
cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
ret, frame = cap.read()
cap.release()
if ret:
# Convert BGR to RGB
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return None
def visualize_episode(
episode_idx: int,
annotation,
video_path: Path,
video_start_timestamp: float,
video_end_timestamp: float,
fps: int,
output_path: Path,
video_key: str,
):
"""
Create visualization for a single episode.
Shows:
- Top row: Sample frames from the episode (one per subtask)
- Bottom: Timeline with subtask segments and boundary lines
"""
subtasks = annotation.subtasks
num_subtasks = len(subtasks)
if num_subtasks == 0:
print(f"No subtasks found for episode {episode_idx}")
return
# Calculate episode duration
episode_duration = video_end_timestamp - video_start_timestamp
# Extract sample frames - get frame from middle of each subtask
sample_frames = []
frame_timestamps = []
for subtask in subtasks:
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
mid_sec = (start_sec + end_sec) / 2
# Convert to video timestamp (add video_start_timestamp offset)
video_timestamp = video_start_timestamp + mid_sec
frame_timestamps.append(mid_sec)
frame = extract_frame_from_video(video_path, video_timestamp)
sample_frames.append(frame)
# Create figure
fig = plt.figure(figsize=(16, 10))
# Use a dark background for better contrast
fig.patch.set_facecolor('#1a1a2e')
# Calculate grid layout
# Top section: frames (variable number of columns based on subtasks)
# Bottom section: timeline
# Create gridspec
gs = fig.add_gridspec(
2, max(num_subtasks, 1),
height_ratios=[2, 1],
hspace=0.3,
wspace=0.1,
left=0.05, right=0.95,
top=0.88, bottom=0.1
)
# Add title
fig.suptitle(
f"Episode {episode_idx} - Subtask Annotations",
fontsize=18,
fontweight='bold',
color='white',
y=0.96
)
# Add subtitle with video info
fig.text(
0.5, 0.91,
f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
ha='center',
fontsize=11,
color='#888888'
)
# Plot sample frames
for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
ax = fig.add_subplot(gs[0, i])
ax.set_facecolor('#16213e')
if frame is not None:
ax.imshow(frame)
else:
ax.text(0.5, 0.5, "Frame\nN/A", ha='center', va='center',
fontsize=12, color='white', transform=ax.transAxes)
ax.set_title(
f"{subtask.name}",
fontsize=10,
fontweight='bold',
color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
pad=8
)
ax.axis('off')
# Add frame timestamp below
ax.text(
0.5, -0.08,
f"t={frame_timestamps[i]:.1f}s",
ha='center',
fontsize=9,
color='#888888',
transform=ax.transAxes
)
# Create timeline subplot spanning all columns
ax_timeline = fig.add_subplot(gs[1, :])
ax_timeline.set_facecolor('#16213e')
# Get total duration from last subtask end time
total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)
# Draw subtask segments as colored bars
bar_height = 0.6
bar_y = 0.5
for i, subtask in enumerate(subtasks):
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]
# Draw segment bar
rect = mpatches.FancyBboxPatch(
(start_sec, bar_y - bar_height/2),
end_sec - start_sec,
bar_height,
boxstyle="round,pad=0.02,rounding_size=0.1",
facecolor=color,
edgecolor='white',
linewidth=1.5,
alpha=0.85
)
ax_timeline.add_patch(rect)
# Add subtask label inside bar
mid_x = (start_sec + end_sec) / 2
duration = end_sec - start_sec
# Only add text if segment is wide enough
if duration > total_duration * 0.08:
ax_timeline.text(
mid_x, bar_y,
subtask.name,
ha='center', va='center',
fontsize=9,
fontweight='bold',
color='black' if i in [3] else 'white', # Yellow needs dark text
rotation=0 if duration > total_duration * 0.15 else 45
)
# Draw boundary lines (dashed vertical lines between subtasks)
boundary_times = []
for i, subtask in enumerate(subtasks):
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
# Add start boundary (except for first subtask at t=0)
if i == 0 and start_sec > 0:
boundary_times.append(start_sec)
elif i > 0:
boundary_times.append(start_sec)
# Add end boundary for last subtask
if i == len(subtasks) - 1:
boundary_times.append(end_sec)
# Draw dashed lines at boundaries
for t in boundary_times:
ax_timeline.axvline(
x=t,
ymin=0.1, ymax=0.9,
color='white',
linestyle='--',
linewidth=2,
alpha=0.9
)
# Add time label below line
ax_timeline.text(
t, 0.0,
f"{int(t//60):02d}:{int(t%60):02d}",
ha='center', va='top',
fontsize=8,
color='#cccccc'
)
# Add start line at t=0
ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color='#00ff00', linestyle='-', linewidth=2.5, alpha=0.9)
ax_timeline.text(0, 0.0, "00:00", ha='center', va='top', fontsize=8, color='#00ff00', fontweight='bold')
# Configure timeline axes
ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
ax_timeline.set_ylim(-0.3, 1.2)
ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color='white', labelpad=10)
ax_timeline.set_ylabel("")
# Style the axes
ax_timeline.spines['top'].set_visible(False)
ax_timeline.spines['right'].set_visible(False)
ax_timeline.spines['left'].set_visible(False)
ax_timeline.spines['bottom'].set_color('#444444')
ax_timeline.tick_params(axis='x', colors='#888888', labelsize=9)
ax_timeline.tick_params(axis='y', left=False, labelleft=False)
# Add x-axis ticks at regular intervals
tick_interval = max(1, int(total_duration / 10))
ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))
# Add legend explaining line styles
legend_elements = [
Line2D([0], [0], color='#00ff00', linewidth=2.5, linestyle='-', label='Start'),
Line2D([0], [0], color='white', linewidth=2, linestyle='--', label='Subtask boundary'),
]
ax_timeline.legend(
handles=legend_elements,
loc='upper right',
framealpha=0.3,
facecolor='#16213e',
edgecolor='#444444',
fontsize=9,
labelcolor='white'
)
# Save figure
plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')
plt.close()
return output_path
def main():
parser = argparse.ArgumentParser(
description="Visualize SARM subtask annotations",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="HuggingFace dataset repository ID",
)
parser.add_argument(
"--num-episodes",
type=int,
default=5,
help="Number of random episodes to visualize (default: 5)",
)
parser.add_argument(
"--episodes",
type=int,
nargs="+",
default=None,
help="Specific episode indices to visualize (overrides --num-episodes)",
)
parser.add_argument(
"--video-key",
type=str,
default=None,
help="Camera/video key to use. If not specified, uses first available.",
)
parser.add_argument(
"--output-dir",
type=str,
default="./subtask_viz",
help="Output directory for visualizations (default: ./subtask_viz)",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility",
)
args = parser.parse_args()
console = Console()
# Set random seed if specified
if args.seed is not None:
random.seed(args.seed)
console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
dataset = LeRobotDataset(args.repo_id, download_videos=True)
fps = dataset.fps
# Get video key
if args.video_key:
if args.video_key not in dataset.meta.video_keys:
console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
console.print(f"[yellow]Available: {', '.join(dataset.meta.video_keys)}[/yellow]")
return
video_key = args.video_key
else:
video_key = dataset.meta.video_keys[0]
console.print(f"[cyan]Using camera: {video_key}[/cyan]")
console.print(f"[cyan]FPS: {fps}[/cyan]")
# Load annotations
console.print(f"\n[cyan]Loading annotations...[/cyan]")
annotations = load_annotations_from_dataset(dataset.root)
if not annotations:
console.print("[red]Error: No annotations found in dataset[/red]")
console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
return
console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")
# Determine which episodes to visualize
if args.episodes:
episode_indices = args.episodes
# Validate episodes exist
for ep in episode_indices:
if ep not in annotations:
console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
episode_indices = [ep for ep in episode_indices if ep in annotations]
else:
# Random selection
available_episodes = list(annotations.keys())
num_to_select = min(args.num_episodes, len(available_episodes))
episode_indices = random.sample(available_episodes, num_to_select)
episode_indices.sort()
if not episode_indices:
console.print("[red]Error: No valid episodes to visualize[/red]")
return
console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")
# Create output directory
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Generate visualizations
for ep_idx in episode_indices:
console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")
annotation = annotations[ep_idx]
# Get video path and timestamps
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
if not video_path.exists():
console.print(f"[red]Video not found: {video_path}[/red]")
continue
# Get episode-specific timestamps within the video file
video_path_key = f"videos/{video_key}/from_timestamp"
video_path_key_to = f"videos/{video_key}/to_timestamp"
video_start_timestamp = float(dataset.meta.episodes[video_path_key][ep_idx])
video_end_timestamp = float(dataset.meta.episodes[video_path_key_to][ep_idx])
# Create visualization
output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"
try:
visualize_episode(
episode_idx=ep_idx,
annotation=annotation,
video_path=video_path,
video_start_timestamp=video_start_timestamp,
video_end_timestamp=video_end_timestamp,
fps=fps,
output_path=output_path,
video_key=video_key,
)
console.print(f"[green]✓ Saved: {output_path}[/green]")
except Exception as e:
console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")
# Print summary
console.print(f"\n[bold green]{'=' * 50}[/bold green]")
console.print(f"[bold green]Visualization Complete![/bold green]")
console.print(f"[bold green]{'=' * 50}[/bold green]")
console.print(f"Output directory: {output_dir.absolute()}")
console.print(f"Episodes visualized: {len(episode_indices)}")
if __name__ == "__main__":
main()
-140
View File
@@ -1,140 +0,0 @@
import time
import numpy as np
import pinocchio as pin
from os.path import dirname
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
same_direction = {"joint_4", "gripper"}
idx = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
# joints to freeze
frozen = {"joint_6", "joint_7", "gripper"}
initial_pose = {}
def pos_deg(rob, obs):
out = {}
for side in ("left", "right"):
for m in getattr(rob, f"bus_{side}").motors:
k = f"{side}_{m}.pos"
if k in obs:
out[f"{side}_{m}"] = obs[k]
return out
def vel_rad(rob, obs):
out = {}
for side in ("left", "right"):
for m in getattr(rob, f"bus_{side}").motors:
k = f"{side}_{m}.vel"
out[f"{side}_{m}"] = np.deg2rad(obs.get(k, 0.0))
return out
def main():
cfg = OpenArmsLeaderConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_bilateral",
manual_control=False,
)
rob = OpenArmsLeader(cfg)
rob.connect(calibrate=True)
urdf = "/home/yope/Documents/lerobot_g1_integration/openarm_description/openarm_bimanual_pybullet.urdf"
rob.pin_robot = pin.RobotWrapper.BuildFromURDF(urdf, dirname(urdf))
rob.pin_robot.data = rob.pin_robot.model.createData()
dt = 0.005
grav = 1.0
fric = 0.3
# capture initial pose to freeze selected joints later
obs0 = rob.get_action()
for side in ("left", "right"):
for m in getattr(rob, f"bus_{side}").motors:
key = f"{side}_{m}.pos"
if key in obs0 and m in frozen:
initial_pose[f"{side}_{m}"] = obs0[key]
try:
while True:
obs = rob.get_action()
pdeg = pos_deg(rob, obs)
prad = {k: np.deg2rad(v) for k, v in pdeg.items()}
vrad = vel_rad(rob, obs)
tau_g = rob._gravity_from_q(prad)
tau_f = rob._friction_from_velocity(vrad, friction_scale=fric)
# bilateral midpoint calculation
cmd = {}
for m in rob.bus_right.motors:
kl = f"left_{m}.pos"
kr = f"right_{m}.pos"
if kl not in obs or kr not in obs:
continue
ql = obs[kl]
qr = obs[kr]
if m in same_direction:
qmid = 0.5 * (ql + qr)
else:
qmid = 0.5 * (ql - qr)
# assign midpoint for both
cmd[f"left_{m}"] = qmid
cmd[f"right_{m}"] = qmid if m in same_direction else -qmid
# override midpoint with frozen values
for key, val in initial_pose.items():
cmd[key] = val
# single mit control call
for side in ("left", "right"):
bus = getattr(rob, f"bus_{side}")
for m in bus.motors:
base_key = f"{side}_{m}"
kp = float(cfg.position_kp[idx[m]])
kd = float(cfg.position_kd[idx[m]])
torque = tau_g.get(base_key, 0.0) * grav + tau_f.get(base_key, 0.0)
pos_cmd = cmd.get(base_key, pdeg.get(base_key, 0.0))
bus._mit_control(
motor=m,
kp=kp,
kd=kd,
position_degrees=pos_cmd,
velocity_deg_per_sec=0.0,
torque=torque,
)
time.sleep(dt)
except KeyboardInterrupt:
pass
rob.bus_left.disable_torque()
rob.bus_right.disable_torque()
rob.disconnect()
if __name__ == "__main__":
main()
@@ -1,416 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive debug script for OpenArms CAN FD communication.
Tests all 4 CAN interfaces with CAN FD support.
"""
import can
import time
import sys
import subprocess
def check_can_interface(port):
"""Check if CAN interface is UP and configured."""
try:
result = subprocess.run(['ip', 'link', 'show', port],
capture_output=True, text=True)
if result.returncode != 0:
return False, "Interface not found", None
output = result.stdout
if 'UP' not in output:
return False, "Interface is DOWN", None
# Check if CAN FD is enabled
is_fd = 'fd on' in output.lower() or 'canfd' in output.lower()
return True, "Interface is UP", is_fd
except FileNotFoundError:
return None, "Cannot check (ip command not found)", None
def test_motor_on_interface(bus, motor_id, timeout=2.0, use_fd=False):
"""
Test a single motor and return all responses.
Returns:
list of (arbitration_id, data) tuples for all responses received
"""
# Send enable command
enable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=use_fd
)
try:
bus.send(enable_msg)
except Exception as e:
return None, f"Send error: {e}"
# Listen for responses
responses = []
start_time = time.time()
while time.time() - start_time < timeout:
msg = bus.recv(timeout=0.1)
if msg:
responses.append((msg.arbitration_id, msg.data, msg.is_fd if hasattr(msg, 'is_fd') else False))
# Send disable command
disable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD],
is_extended_id=False,
is_fd=use_fd
)
try:
bus.send(disable_msg)
except:
pass
return responses, None
def test_interface(port, interface_type="socketcan", use_can_fd=True):
"""Test all 8 motors on a single CAN interface."""
results = {
'interface': port,
'status': None,
'is_fd': use_can_fd,
'motors': {}
}
# Check interface status
status_ok, status_msg, interface_has_fd = check_can_interface(port)
if interface_has_fd is not None:
results['interface_fd_enabled'] = interface_has_fd
if use_can_fd and not interface_has_fd:
status_msg += " (CAN FD NOT enabled on interface!)"
elif interface_has_fd:
status_msg += " (CAN FD enabled)"
results['status'] = status_msg
if status_ok is False:
return results
# Try to connect
try:
if use_can_fd:
print(f" Connecting to {port} with CAN FD (1 Mbps / 5 Mbps)...")
bus = can.interface.Bus(
channel=port,
interface=interface_type,
bitrate=1000000,
data_bitrate=5000000,
fd=True
)
else:
print(f" Connecting to {port} with CAN 2.0 (1 Mbps)...")
bus = can.interface.Bus(
channel=port,
interface=interface_type,
bitrate=1000000
)
except Exception as e:
results['status'] = f"Connection failed: {e}"
return results
try:
# Clear any pending messages
while bus.recv(timeout=0.01):
pass
# Test each motor (0x01 to 0x08)
for motor_id in range(0x01, 0x09):
responses, error = test_motor_on_interface(bus, motor_id, timeout=1.0, use_fd=use_can_fd)
if error:
results['motors'][motor_id] = {'error': error}
elif responses:
results['motors'][motor_id] = {
'found': True,
'responses': responses
}
else:
results['motors'][motor_id] = {
'found': False,
'responses': []
}
time.sleep(0.05) # Small delay between motors
finally:
bus.shutdown()
return results
def print_results(all_results):
"""Print formatted results for all interfaces."""
print("SUMMARY - Motors Found on Each Interface")
motor_names = {
0x01: "joint_1 (Shoulder pan)",
0x02: "joint_2 (Shoulder lift)",
0x03: "joint_3 (Shoulder rotation)",
0x04: "joint_4 (Elbow flex)",
0x05: "joint_5 (Wrist roll)",
0x06: "joint_6 (Wrist pitch)",
0x07: "joint_7 (Wrist rotation)",
0x08: "gripper",
}
total_found = 0
for result in all_results:
interface = result['interface']
status = result['status']
print(f"{interface}: {status}")
if result.get('is_fd'):
print(f" Mode: CAN FD")
else:
print(f" Mode: CAN 2.0")
if 'Connection failed' in status or 'DOWN' in status:
print(f" ⚠ Cannot test {interface}")
continue
motors_found = 0
for motor_id in range(0x01, 0x09):
motor_data = result['motors'].get(motor_id, {})
motor_name = motor_names.get(motor_id, "Unknown")
if motor_data.get('error'):
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {motor_data['error']}")
elif motor_data.get('found'):
motors_found += 1
total_found += 1
responses = motor_data['responses']
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND")
for resp_id, data, is_fd in responses:
data_hex = data.hex()
fd_flag = " [FD]" if is_fd else " [2.0]"
print(f" → Response from 0x{resp_id:02X}{fd_flag}: {data_hex}")
else:
print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response")
print(f"\n Summary: {motors_found}/8 motors found on {interface}")
# Overall summary
print("OVERALL SUMMARY")
print(f"Total motors found across all interfaces: {total_found}")
# Analyze configuration
print("DIAGNOSIS")
for result in all_results:
interface = result['interface']
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
if motors_found == 0:
print(f"\n{interface}: NO MOTORS FOUND")
print(" Possible issues:")
print(" 1. CAN FD mode mismatch (interface vs motor configuration)")
print(" 2. Missing 120Ω termination resistors at BOTH cable ends")
print(" 3. Motor timeout parameter set incorrectly (should NOT be 0)")
print(" 4. CANH/CANL wiring issue")
print(" 5. Cable too long (>40m for CAN FD at 5Mbps)")
# Check FD mismatch
if result.get('is_fd') and not result.get('interface_fd_enabled'):
print(" ⚠️ CRITICAL: Trying CAN FD but interface NOT configured for FD!")
print(f" Fix: sudo ip link set {interface} type can bitrate 1000000 dbitrate 5000000 fd on")
elif motors_found < 8:
print(f"\n{interface}: Only {motors_found}/8 motors responding")
print(" Check power and connections for missing motors")
else:
print(f"\n{interface}: All 8 motors responding correctly!")
# Check for unexpected response IDs
print("RESPONSE ID ANALYSIS")
for result in all_results:
interface = result['interface']
unexpected = []
for motor_id, motor_data in result['motors'].items():
if motor_data.get('found'):
expected_id = motor_id + 0x10
actual_ids = [resp[0] for resp in motor_data['responses']]
if expected_id not in actual_ids:
unexpected.append((motor_id, actual_ids))
if unexpected:
print(f"\n{interface}: Unexpected response IDs detected")
for motor_id, actual_ids in unexpected:
expected_id = motor_id + 0x10
print(f" Motor 0x{motor_id:02X}: Expected 0x{expected_id:02X}, "
f"got {[f'0x{id:02X}' for id in actual_ids]}")
print(" → Motor Master IDs need reconfiguration")
else:
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
if motors_found > 0:
print(f"\n{interface}: All responding motors use correct IDs")
def test_communication_speed(interface, motor_id, num_iterations=100):
"""
Test communication speed with a motor.
Returns:
tuple: (hz, avg_latency_ms) or (None, None) if test failed
"""
try:
# Connect to interface
bus = can.interface.Bus(
channel=interface,
interface="socketcan",
bitrate=1000000,
data_bitrate=5000000,
fd=True
)
# Send refresh commands and measure round-trip time
latencies = []
successful = 0
for _ in range(num_iterations):
start = time.perf_counter()
# Send enable command (lightweight operation)
enable_msg = can.Message(
arbitration_id=motor_id,
data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC],
is_extended_id=False,
is_fd=True
)
bus.send(enable_msg)
# Wait for response
msg = bus.recv(timeout=0.1)
if msg:
latency = (time.perf_counter() - start) * 1000 # Convert to ms
latencies.append(latency)
successful += 1
bus.shutdown()
if successful > 0:
avg_latency = sum(latencies) / len(latencies)
hz = 1000.0 / avg_latency if avg_latency > 0 else 0
return hz, avg_latency
return None, None
except Exception as e:
print(f" Speed test error: {e}")
return None, None
def main():
"""Main function to test all CAN interfaces with CAN FD."""
print("\nThis will test all 4 CAN interfaces (can0-can3) with CAN FD")
print("Testing motors 0x01-0x08 on each interface")
print()
print("Make sure:")
print(" ✓ Motors are powered (24V)")
print(" ✓ CAN interfaces configured with FD mode:")
print(" ./examples/openarms/setup_can.sh")
print(" ✓ Motor 'timeout' parameter NOT set to 0 (use Damiao tools)")
print(" ✓ CAN wiring includes 120Ω termination at BOTH ends")
print()
input("Press ENTER to start testing...")
# Test all 4 interfaces with CAN FD
all_results = []
for i in range(4):
interface = f"can{i}"
print(f"Testing {interface}...")
result = test_interface(interface, use_can_fd=True)
all_results.append(result)
# Quick status
if 'Connection failed' in result['status'] or 'DOWN' in result['status']:
print(f"{interface}: {result['status']}")
else:
motors_found = sum(1 for m in result['motors'].values() if m.get('found'))
print(f" {interface}: {motors_found}/8 motors found")
time.sleep(0.2)
# Print detailed results
print_results(all_results)
print("Testing Complete!")
all_found = sum(sum(1 for m in r['motors'].values() if m.get('found')) for r in all_results)
if all_found == 0:
print("\n⚠️ CRITICAL: No motors found on any interface!")
print("\nTop issues to check:")
print(" 1. Motor 'timeout' parameter (use Damiao tools to set > 0)")
print(" 2. CAN FD not enabled (run ./examples/openarms/setup_can.sh)")
print(" 3. Missing termination resistors")
print("\nTry:")
print(" a) Check motor parameters with Damiao Debugging Tools")
print(" b) Verify CAN FD is enabled: ip -d link show can0 | grep fd")
print(" c) Run setup script: ./examples/openarms/setup_can.sh")
else:
# Run speed test on interfaces with motors
print("COMMUNICATION SPEED TEST")
print("\nTesting maximum communication frequency...")
for result in all_results:
interface = result['interface']
# Find first responding motor
responding_motor = None
for motor_id, motor_data in result['motors'].items():
if motor_data.get('found'):
responding_motor = motor_id
break
if responding_motor:
print(f"\n{interface}: Testing with motor 0x{responding_motor:02X}...")
hz, latency = test_communication_speed(interface, responding_motor, num_iterations=100)
if hz:
print(f" ✓ Max frequency: {hz:.1f} Hz")
print(f" ✓ Avg latency: {latency:.2f} ms")
print(f" ✓ Commands per second: ~{int(hz)}")
else:
print(f" ✗ Speed test failed")
else:
print(f"\n{interface}: No motors found, skipping speed test")
print()
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n\nTesting interrupted by user.")
sys.exit(1)
except Exception as e:
print(f"\nUnexpected error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
-360
View File
@@ -1,360 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenArms Policy Evaluation
Evaluates a trained policy on the OpenArms robot by running inference and recording
the evaluation episodes to a dataset. Supports optional leader arm for manual resets.
Example usage:
python examples/openarms/evaluate.py
"""
import time
from pathlib import Path
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.processor import make_default_processors
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0" # TODO: Replace with your trained model
HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval7" # TODO: Replace with your eval dataset name
TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task, this should match!!
NUM_EPISODES = 1
FPS = 30
EPISODE_TIME_SEC = 300
RESET_TIME_SEC = 60
# Robot CAN interfaces
FOLLOWER_LEFT_PORT = "can0"
FOLLOWER_RIGHT_PORT = "can1"
# If enabled, you can manually reset the environment between evaluation episodes
USE_LEADER_FOR_RESETS = True # Set to False if you don't want to use leader
LEADER_LEFT_PORT = "can2"
LEADER_RIGHT_PORT = "can3"
# Camera configuration
CAMERA_CONFIG = {
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=FPS),
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=FPS),
}
def main():
"""Main evaluation function."""
print("OpenArms Policy Evaluation")
print(f"\nModel: {HF_MODEL_ID}")
print(f"Evaluation Dataset: {HF_EVAL_DATASET_ID}")
print(f"Task: {TASK_DESCRIPTION}")
print(f"Episodes: {NUM_EPISODES}")
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
print(f"Reset Duration: {RESET_TIME_SEC}s")
print(f"Use Leader for Resets: {USE_LEADER_FOR_RESETS}")
follower_config = OpenArmsFollowerConfig(
port_left=FOLLOWER_LEFT_PORT,
port_right=FOLLOWER_RIGHT_PORT,
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
cameras=CAMERA_CONFIG,
)
follower = OpenArmsFollower(follower_config)
follower.connect(calibrate=False)
if not follower.is_connected:
raise RuntimeError("Follower robot failed to connect!")
leader = None
if USE_LEADER_FOR_RESETS:
leader_config = OpenArmsLeaderConfig(
port_left=LEADER_LEFT_PORT,
port_right=LEADER_RIGHT_PORT,
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
leader = OpenArmsLeader(leader_config)
leader.connect(calibrate=False)
if not leader.is_connected:
raise RuntimeError("Leader robot failed to connect!")
# Enable gravity compensation
if leader.pin_robot is not None:
leader.bus_right.enable_torque()
leader.bus_left.enable_torque()
time.sleep(0.1)
print(f"Leader connected with gravity compensation ({LEADER_LEFT_PORT}, {LEADER_RIGHT_PORT})")
else:
print(f"Leader connected but gravity compensation unavailable (no URDF)")
# Build default processors for action and observation
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Build dataset features from robot features and processors
# For actions, only include positions (no velocity or torque)
action_features_hw = {}
for key, value in follower.action_features.items():
if key.endswith(".pos"):
action_features_hw[key] = value
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=action_features_hw),
use_videos=True,
),
aggregate_pipeline_dataset_features(
pipeline=robot_observation_processor,
initial_features=create_initial_features(observation=follower.observation_features),
use_videos=True,
),
)
# Check if dataset already exists
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
if dataset_path.exists():
print(f"Evaluation dataset already exists at: {dataset_path}")
print("This will append new episodes to the existing dataset.")
choice = input(" Continue? (y/n): ").strip().lower()
if choice != 'y':
print(" Aborting evaluation.")
follower.disconnect()
if leader:
leader.disconnect()
return
# Create dataset
dataset = LeRobotDataset.create(
repo_id=HF_EVAL_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=follower.name,
use_videos=True,
image_writer_processes=0,
image_writer_threads=12,
)
# Load policy config from pretrained model and create policy using factory
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
policy = make_policy(policy_config, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy.config,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
preprocessor_overrides={
"device_processor": {"device": str(policy.config.device)}
},
)
print(f"\nRunning evaluation...")
# Initialize keyboard listener and visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="openarms_evaluation")
episode_idx = 0
try:
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
print(f"\nRunning inference for episode {episode_idx + 1}...")
# Run inference with policy
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Handle re-recording
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
print(f"Saving episode {episode_idx + 1} ({dataset.episode_buffer['size']} frames)...")
dataset.save_episode()
episode_idx += 1
# Reset environment between episodes (if not last episode)
if not events["stop_recording"] and episode_idx < NUM_EPISODES:
if USE_LEADER_FOR_RESETS and leader:
log_say("Reset the environment using leader arms")
print(f"\nManual reset period ({RESET_TIME_SEC}s)...")
# Use leader for manual reset with gravity compensation
import numpy as np
dt = 1 / FPS
reset_start_time = time.perf_counter()
while time.perf_counter() - reset_start_time < RESET_TIME_SEC:
if events["exit_early"] or events["stop_recording"]:
break
loop_start = time.perf_counter()
# Get leader state
leader_action = leader.get_action()
# Extract positions and velocities
leader_positions_deg = {}
leader_velocities_deg_per_sec = {}
for motor in leader.bus_right.motors:
pos_key = f"right_{motor}.pos"
vel_key = f"right_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
for motor in leader.bus_left.motors:
pos_key = f"left_{motor}.pos"
vel_key = f"left_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
# Calculate gravity and friction torques
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
leader_friction_torques_nm = leader._friction_from_velocity(
leader_velocities_rad_per_sec,
friction_scale=1.0
)
# Combine torques
leader_total_torques_nm = {}
for motor_name in leader_gravity_torques_nm:
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
friction = leader_friction_torques_nm.get(motor_name, 0.0)
leader_total_torques_nm[motor_name] = gravity + friction
# Apply compensation
for motor in leader.bus_right.motors:
full_name = f"right_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
kd = leader.get_damping_kd(motor)
leader.bus_right._mit_control(
motor=motor, kp=0.0, kd=kd,
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
for motor in leader.bus_left.motors:
full_name = f"left_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
kd = leader.get_damping_kd(motor)
leader.bus_left._mit_control(
motor=motor, kp=0.0, kd=kd,
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Send leader positions to follower
follower_action = {}
for joint in leader_positions_deg.keys():
pos_key = f"{joint}.pos"
if pos_key in leader_action:
follower_action[pos_key] = leader_action[pos_key]
if follower_action:
follower.send_action(follower_action)
# Maintain loop rate
loop_duration = time.perf_counter() - loop_start
sleep_time = dt - loop_duration
if sleep_time > 0:
time.sleep(sleep_time)
print("Reset complete")
else:
log_say("Waiting for manual reset")
print(f"Manually reset the environment and press ENTER to continue")
input("Press ENTER when ready...")
print(f"Evaluation complete! {episode_idx} episodes recorded")
log_say("Evaluation complete", blocking=True)
except KeyboardInterrupt:
print("\n\nEvaluation interrupted by user")
finally:
if leader:
leader.bus_right.disable_torque()
leader.bus_left.disable_torque()
time.sleep(0.1)
leader.disconnect()
follower.disconnect()
if listener is not None:
listener.stop()
dataset.finalize()
print("\nUploading to Hugging Face Hub...")
dataset.push_to_hub(private=True)
if __name__ == "__main__":
main()
-216
View File
@@ -1,216 +0,0 @@
import time
import numpy as np
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
# Friction model parameters from OpenArms config/follower.yaml
# τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
# For 8 motors: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
FRICTION_PARAMS = {
"Fc": [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512], # Coulomb friction [Nm]
"k": [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000], # tanh steepness
"Fv": [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084], # Viscous friction [Nm·s/rad]
"Fo": [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050], # Offset torque [Nm]
}
# Constants from OpenArms C++ implementation
AMP_TMP = 1.0
COEF_TMP = 0.1
FRICTION_SCALE = 1.0 # OpenArms C++ uses 0.3 factor in unilateral mode
DAMPING_KD = [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1] # Damping gains for stability
def compute_friction_torque(velocity_rad_per_sec: float, motor_index: int) -> float:
"""
Compute friction torque for a single motor using the tanh friction model.
Args:
velocity_rad_per_sec: Angular velocity in rad/s
motor_index: Index of the motor (0-7)
Returns:
Friction torque in N·m (scaled for stability)
"""
Fc = FRICTION_PARAMS["Fc"][motor_index]
k = FRICTION_PARAMS["k"][motor_index]
Fv = FRICTION_PARAMS["Fv"][motor_index]
Fo = FRICTION_PARAMS["Fo"][motor_index]
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
friction_torque = (
AMP_TMP * Fc * np.tanh(COEF_TMP * k * velocity_rad_per_sec) +
Fv * velocity_rad_per_sec +
Fo
)
# Scale down friction compensation for stability at lower control rates
# (OpenArms C++ uses 0.3 factor in unilateral mode)!!
friction_torque *= FRICTION_SCALE
return friction_torque
def main() -> None:
config = OpenArmsFollowerConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=5.0,
)
print("Initializing robot...")
follower = OpenArmsFollower(config)
follower.connect(calibrate=True)
print(f"Applying friction compensation")
print(" 1. Support the arm before starting")
print(" 2. The arm will be held in place by friction compensation")
print(" 3. You should be able to move it with gentle force")
print("\nPress ENTER when ready to start...")
input()
print(f"✓ Motors enabled")
print("\nStarting friction compensation loop...")
print("Press Ctrl+C to stop\n")
loop_times = []
last_print_time = time.perf_counter()
# Motor name to index mapping
motor_name_to_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
try:
while True:
loop_start = time.perf_counter()
# Get current joint positions and velocities from robot
obs = follower.get_observation()
# Extract velocities in degrees per second
velocities_deg_per_sec = {}
positions_deg = {}
for motor in follower.bus_right.motors:
vel_key = f"right_{motor}.vel"
pos_key = f"right_{motor}.pos"
if vel_key in obs:
velocities_deg_per_sec[f"right_{motor}"] = obs[vel_key]
if pos_key in obs:
positions_deg[f"right_{motor}"] = obs[pos_key]
for motor in follower.bus_left.motors:
vel_key = f"left_{motor}.vel"
pos_key = f"left_{motor}.pos"
if vel_key in obs:
velocities_deg_per_sec[f"left_{motor}"] = obs[vel_key]
if pos_key in obs:
positions_deg[f"left_{motor}"] = obs[pos_key]
# Convert velocities to rad/s and compute friction torques
friction_torques_nm = {}
for motor_full_name, velocity_deg_per_sec in velocities_deg_per_sec.items():
# Extract motor name without arm prefix
if motor_full_name.startswith("right_"):
motor_name = motor_full_name.removeprefix("right_")
elif motor_full_name.startswith("left_"):
motor_name = motor_full_name.removeprefix("left_")
else:
continue
# Get motor index for friction parameters
motor_index = motor_name_to_index.get(motor_name, 0)
# Convert velocity to rad/s
velocity_rad_per_sec = np.deg2rad(velocity_deg_per_sec)
# Compute friction torque
friction_torque = compute_friction_torque(velocity_rad_per_sec, motor_index)
friction_torques_nm[motor_full_name] = friction_torque
# Apply friction compensation to right arm (all joints INCLUDING gripper)
for motor in follower.bus_right.motors:
full_name = f"right_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = friction_torques_nm.get(full_name, 0.0)
# Get motor index for damping gain
motor_index = motor_name_to_index.get(motor, 0)
kd = DAMPING_KD[motor_index]
# Send MIT control command with friction compensation + damping
follower.bus_right._mit_control(
motor=motor,
kp=0.0, # No position control
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Apply friction compensation to left arm (all joints INCLUDING gripper)
for motor in follower.bus_left.motors:
full_name = f"left_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = friction_torques_nm.get(full_name, 0.0)
# Get motor index for damping gain
motor_index = motor_name_to_index.get(motor, 0)
kd = DAMPING_KD[motor_index]
# Send MIT control command with friction compensation + damping
follower.bus_left._mit_control(
motor=motor,
kp=0.0, # No position control
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Measure loop time
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Print status every 2 seconds
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
print(f"{current_hz:.1f} Hz")
loop_times = []
last_print_time = loop_end
time.sleep(0.001)
except KeyboardInterrupt:
print("\n\nStopping friction compensation...")
finally:
print("\nDisabling all motors and disconnecting...")
follower.bus_right.disable_torque()
follower.bus_left.disable_torque()
time.sleep(0.1)
follower.disconnect()
print("✓ Safe shutdown complete")
if __name__ == "__main__":
main()
-139
View File
@@ -1,139 +0,0 @@
import time
import numpy as np
import pinocchio as pin
from os.path import join, dirname, exists, expanduser
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
def main() -> None:
config = OpenArmsLeaderConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
print("Initializing robot...")
follower = OpenArmsLeader(config)
follower.connect(calibrate=True)
# Load URDF for Pinocchio dynamics
urdf_path = "/home/yope/Documents/lerobot_g1_integration/openarm_description/openarm_bimanual_pybullet.urdf"
pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, dirname(urdf_path))
pin_robot.data = pin_robot.model.createData()
print(f"✓ Loaded Pinocchio model with {pin_robot.nq} DoFs")
follower.pin_robot = pin_robot
print(f"Applying gravity compensation")
print(" 1. Support the arm before starting")
print(" 2. The arm will be held in place by gravity compensation")
print(" 3. You should be able to move it with gentle force")
print("\nPress ENTER when ready to start...")
input()
print(f"✓ Motors enabled")
print("\nStarting gravity compensation loop...")
print("Press Ctrl+C to stop\n")
loop_times = []
last_print_time = time.perf_counter()
try:
while True:
loop_start = time.perf_counter()
# Get current joint positions from robot
obs = follower.get_action()
# Extract positions in degrees
positions_deg = {}
for motor in follower.bus_right.motors:
key = f"right_{motor}.pos"
if key in obs:
positions_deg[f"right_{motor}"] = obs[key]
for motor in follower.bus_left.motors:
key = f"left_{motor}.pos"
if key in obs:
positions_deg[f"left_{motor}"] = obs[key]
# Convert to radians and calculate gravity torques
# Use the built-in method from OpenArmsFollower
positions_rad = {k: np.deg2rad(v) for k, v in positions_deg.items()}
torques_nm = follower._gravity_from_q(positions_rad)
# Apply gravity compensation to right arm (all joints except gripper)
for motor in follower.bus_right.motors:
full_name = f"right_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = torques_nm.get(full_name, 0.0)
# Send MIT control command with gravity compensation torque
follower.bus_right._mit_control(
motor=motor,
kp=0.0, # No position control
kd=0.0, # No velocity damping
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Apply gravity compensation to left arm (all joints except gripper)
for motor in follower.bus_left.motors:
full_name = f"left_{motor}"
position = positions_deg.get(full_name, 0.0)
torque = torques_nm.get(full_name, 0.0)
# Send MIT control command with gravity compensation torque
follower.bus_left._mit_control(
motor=motor,
kp=0.0, # No position control
kd=0.0, # No velocity damping
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque
)
# Measure loop time
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Print status every 2 seconds
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
loop_times = []
last_print_time = loop_end
time.sleep(0.005)
except KeyboardInterrupt:
print("\n\nStopping gravity compensation...")
finally:
print("\nDisabling all motors and disconnecting...")
follower.bus_right.disable_torque()
follower.bus_left.disable_torque()
time.sleep(0.1)
follower.disconnect()
print("✓ Safe shutdown complete")
if __name__ == "__main__":
main()
@@ -1,618 +0,0 @@
<?xml version='1.0' encoding='utf-8'?>
<robot name="openarm">
<link name="world" />
<joint name="openarm_body_world_joint" type="fixed">
<parent link="world" />
<child link="openarm_body_link0" />
<origin rpy="0 0 0" xyz="0 0 0" />
</joint>
<link name="openarm_body_link0">
<visual name="openarm_body_link0_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/body/v10/visual/body_link0.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_body_link0_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/body/v10/collision/body_link0_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<mass value="13.89" />
<inertia ixx="1.653" ixy="0.0" ixz="0.0" iyy="1.653" iyz="0.0" izz="0.051" />
</inertial>
</link>
<joint name="openarm_left_openarm_body_link0_joint" type="fixed">
<parent link="openarm_body_link0" />
<child link="openarm_left_link0" />
<origin rpy="-1.5708 0 0" xyz="0.0 0.031 0.698" />
</joint>
<link name="openarm_left_link0">
<visual name="openarm_left_link0_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link0.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link0_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link0_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.0009483362816297526 -0.0001580207020448382 0.03076860287587199" />
<mass value="1.1432284943239561" />
<inertia ixx="0.001128" ixy="-4e-06" ixz="-3.3e-05" iyy="0.000962" iyz="-7e-06" izz="0.00147" />
</inertial>
</link>
<link name="openarm_left_link1">
<visual name="openarm_left_link1_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link1.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link1_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link1_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.0011467657911800769 -3.319987657026362e-05 0.05395284380736254" />
<mass value="1.1416684646202298" />
<inertia ixx="0.001567" ixy="-1e-06" ixz="-2.9e-05" iyy="0.001273" iyz="1e-06" izz="0.001016" />
</inertial>
</link>
<joint name="openarm_left_joint1" type="revolute">
<origin rpy="0 0 0" xyz="0.0 0.0 0.0625" />
<parent link="openarm_left_link0" />
<child link="openarm_left_link1" />
<axis xyz="0 0 1" />
<limit effort="40" lower="-3.490659" upper="1.3962629999999998" velocity="16.754666" />
</joint>
<link name="openarm_left_link2">
<visual name="openarm_left_link2_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link2.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link2_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link2_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.00839629182351943 2.0145102027597523e-08 0.03256649300522363" />
<mass value="0.2775092746011571" />
<inertia ixx="0.000359" ixy="1e-06" ixz="-0.000109" iyy="0.000376" iyz="1e-06" izz="0.000232" />
</inertial>
</link>
<joint name="openarm_left_joint2" type="revolute">
<origin rpy="-1.57079632679 0 0" xyz="-0.0301 0.0 0.06" />
<parent link="openarm_left_link1" />
<child link="openarm_left_link2" />
<axis xyz="-1 0 0" />
<limit effort="40" lower="-3.3161253267948965" upper="0.17453267320510335" velocity="16.754666" />
</joint>
<link name="openarm_left_link3">
<visual name="openarm_left_link3_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link3.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link3_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link3_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.002104752099628911 -0.0005549085042607548 0.09047470545721961" />
<mass value="1.073863338202347" />
<inertia ixx="0.004372" ixy="1e-06" ixz="1.1e-05" iyy="0.004319" iyz="-3.6e-05" izz="0.000661" />
</inertial>
</link>
<joint name="openarm_left_joint3" type="revolute">
<origin rpy="0 0 0" xyz="0.0301 0.0 0.06625" />
<parent link="openarm_left_link2" />
<child link="openarm_left_link3" />
<axis xyz="0 0 1" />
<limit effort="27" lower="-1.570796" upper="1.570796" velocity="5.445426" />
</joint>
<link name="openarm_left_link4">
<visual name="openarm_left_link4_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link4.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link4_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link4_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.0029006831074562967 -0.03030575826634669 0.06339637422196209" />
<mass value="0.6348534566833373" />
<inertia ixx="0.000623" ixy="-1e-06" ixz="-1.9e-05" iyy="0.000511" iyz="3.8e-05" izz="0.000334" />
</inertial>
</link>
<joint name="openarm_left_joint4" type="revolute">
<origin rpy="0 0 0" xyz="-0.0 0.0315 0.15375" />
<parent link="openarm_left_link3" />
<child link="openarm_left_link4" />
<axis xyz="0 1 0" />
<limit effort="27" lower="0.0" upper="2.443461" velocity="5.445426" />
</joint>
<link name="openarm_left_link5">
<visual name="openarm_left_link5_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link5.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link5_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link5_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.003049665024221911 -0.0008866902457326625 0.043079803024980934" />
<mass value="0.6156588026168502" />
<inertia ixx="0.000423" ixy="-8e-06" ixz="6e-06" iyy="0.000445" iyz="-6e-06" izz="0.000324" />
</inertial>
</link>
<joint name="openarm_left_joint5" type="revolute">
<origin rpy="0 0 0" xyz="0.0 -0.0315 0.0955" />
<parent link="openarm_left_link4" />
<child link="openarm_left_link5" />
<axis xyz="0 0 1" />
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
</joint>
<link name="openarm_left_link6">
<visual name="openarm_left_link6_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link6.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link6_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link6_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.037136587005447405 -0.00033230528343419053 -9.498374522309838e-05" />
<mass value="0.475202773187987" />
<inertia ixx="0.000143" ixy="1e-06" ixz="1e-06" iyy="0.000157" iyz="1e-06" izz="0.000159" />
</inertial>
</link>
<joint name="openarm_left_joint6" type="revolute">
<origin rpy="0 0 0" xyz="0.0375 0.0 0.1205" />
<parent link="openarm_left_link5" />
<child link="openarm_left_link6" />
<axis xyz="1 0 0" />
<limit effort="7" lower="-0.785398" upper="0.785398" velocity="20.943946" />
</joint>
<link name="openarm_left_link7">
<visual name="openarm_left_link7_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link7.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_link7_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link7_symp.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="6.875510271106056e-05 -0.01266175250761268 0.06951945409987448" />
<mass value="0.4659771327380578" />
<inertia ixx="0.000639" ixy="1e-06" ixz="1e-06" iyy="0.000497" iyz="8.9e-05" izz="0.000342" />
</inertial>
</link>
<joint name="openarm_left_joint7" type="revolute">
<origin rpy="0 0 0" xyz="-0.0375 0.0 0.0" />
<parent link="openarm_left_link6" />
<child link="openarm_left_link7" />
<axis xyz="0 -1 0" />
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
</joint>
<joint name="openarm_right_openarm_body_link0_joint" type="fixed">
<parent link="openarm_body_link0" />
<child link="openarm_right_link0" />
<origin rpy="1.5708 0 0" xyz="0.0 -0.031 0.698" />
</joint>
<link name="openarm_right_link0">
<visual name="openarm_right_link0_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link0.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link0_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link0_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.0009483362816297526 0.0001580207020448382 0.03076860287587199" />
<mass value="1.1432284943239561" />
<inertia ixx="0.001128" ixy="-4e-06" ixz="-3.3e-05" iyy="0.000962" iyz="-7e-06" izz="0.00147" />
</inertial>
</link>
<link name="openarm_right_link1">
<visual name="openarm_right_link1_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link1.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link1_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link1_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.0011467657911800769 3.319987657026362e-05 0.05395284380736254" />
<mass value="1.1416684646202298" />
<inertia ixx="0.001567" ixy="-1e-06" ixz="-2.9e-05" iyy="0.001273" iyz="1e-06" izz="0.001016" />
</inertial>
</link>
<joint name="openarm_right_joint1" type="revolute">
<origin rpy="0 0 0" xyz="0.0 0.0 0.0625" />
<parent link="openarm_right_link0" />
<child link="openarm_right_link1" />
<axis xyz="0 0 1" />
<limit effort="40" lower="-1.396263" upper="3.490659" velocity="16.754666" />
</joint>
<link name="openarm_right_link2">
<visual name="openarm_right_link2_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link2.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link2_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link2_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="0.00839629182351943 -2.0145102027597523e-08 0.03256649300522363" />
<mass value="0.2775092746011571" />
<inertia ixx="0.000359" ixy="1e-06" ixz="-0.000109" iyy="0.000376" iyz="1e-06" izz="0.000232" />
</inertial>
</link>
<joint name="openarm_right_joint2" type="revolute">
<origin rpy="1.57079632679 0 0" xyz="-0.0301 0.0 0.06" />
<parent link="openarm_right_link1" />
<child link="openarm_right_link2" />
<axis xyz="-1 0 0" />
<limit effort="40" lower="-0.17453267320510335" upper="3.3161253267948965" velocity="16.754666" />
</joint>
<link name="openarm_right_link3">
<visual name="openarm_right_link3_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link3.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link3_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link3_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.002104752099628911 0.0005549085042607548 0.09047470545721961" />
<mass value="1.073863338202347" />
<inertia ixx="0.004372" ixy="1e-06" ixz="1.1e-05" iyy="0.004319" iyz="-3.6e-05" izz="0.000661" />
</inertial>
</link>
<joint name="openarm_right_joint3" type="revolute">
<origin rpy="0 0 0" xyz="0.0301 0.0 0.06625" />
<parent link="openarm_right_link2" />
<child link="openarm_right_link3" />
<axis xyz="0 0 1" />
<limit effort="27" lower="-1.570796" upper="1.570796" velocity="5.445426" />
</joint>
<link name="openarm_right_link4">
<visual name="openarm_right_link4_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link4.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link4_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link4_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.0029006831074562967 -0.03030575826634669 0.06339637422196209" />
<mass value="0.6348534566833373" />
<inertia ixx="0.000623" ixy="-1e-06" ixz="-1.9e-05" iyy="0.000511" iyz="3.8e-05" izz="0.000334" />
</inertial>
</link>
<joint name="openarm_right_joint4" type="revolute">
<origin rpy="0 0 0" xyz="-0.0 0.0315 0.15375" />
<parent link="openarm_right_link3" />
<child link="openarm_right_link4" />
<axis xyz="0 1 0" />
<limit effort="27" lower="0.0" upper="2.443461" velocity="5.445426" />
</joint>
<link name="openarm_right_link5">
<visual name="openarm_right_link5_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link5.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link5_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link5_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.003049665024221911 0.0008866902457326625 0.043079803024980934" />
<mass value="0.6156588026168502" />
<inertia ixx="0.000423" ixy="-8e-06" ixz="6e-06" iyy="0.000445" iyz="-6e-06" izz="0.000324" />
</inertial>
</link>
<joint name="openarm_right_joint5" type="revolute">
<origin rpy="0 0 0" xyz="0.0 -0.0315 0.0955" />
<parent link="openarm_right_link4" />
<child link="openarm_right_link5" />
<axis xyz="0 0 1" />
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
</joint>
<link name="openarm_right_link6">
<visual name="openarm_right_link6_visual">
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link6.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link6_collision">
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link6_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="-0.037136587005447405 0.00033230528343419053 -9.498374522309838e-05" />
<mass value="0.475202773187987" />
<inertia ixx="0.000143" ixy="1e-06" ixz="1e-06" iyy="0.000157" iyz="1e-06" izz="0.000159" />
</inertial>
</link>
<joint name="openarm_right_joint6" type="revolute">
<origin rpy="0 0 0" xyz="0.0375 0.0 0.1205" />
<parent link="openarm_right_link5" />
<child link="openarm_right_link6" />
<axis xyz="1 0 0" />
<limit effort="7" lower="-0.785398" upper="0.785398" velocity="20.943946" />
</joint>
<link name="openarm_right_link7">
<visual name="openarm_right_link7_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/visual/link7.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_link7_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
<geometry>
<mesh filename="./meshes/arm/v10/collision/link7_symp.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0.0 0.0 0.0" xyz="6.875510271106056e-05 0.01266175250761268 0.06951945409987448" />
<mass value="0.4659771327380578" />
<inertia ixx="0.000639" ixy="1e-06" ixz="1e-06" iyy="0.000497" iyz="8.9e-05" izz="0.000342" />
</inertial>
</link>
<joint name="openarm_right_joint7" type="revolute">
<origin rpy="0 0 0" xyz="-0.0375 0.0 0.0" />
<parent link="openarm_right_link6" />
<child link="openarm_right_link7" />
<axis xyz="0 1 0" />
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
</joint>
<link name="openarm_left_hand">
<visual name="openarm_left_hand_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/hand.dae" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_hand_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/hand.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0 0.002 0.03" />
<mass value="0.35" />
<inertia ixx="0.0002473" ixy="1e-06" ixz="1e-06" iyy="1.763e-05" iyz="1e-06" izz="0.0002521" />
</inertial>
</link>
<joint name="left_openarm_hand_joint" type="fixed">
<parent link="openarm_left_link7" />
<child link="openarm_left_hand" />
<origin rpy="0 0 0" xyz="0 -0.0 0.1001" />
</joint>
<link name="openarm_left_hand_tcp">
<inertial>
<origin xyz="0 0 0" rpy="0 0 0" />
<mass value="0.001" />
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
</inertial>
</link>
<joint name="openarm_left_hand_tcp_joint" type="fixed">
<origin rpy="0 0 0" xyz="0 -0.0 0.08" />
<parent link="openarm_left_hand" />
<child link="openarm_left_hand_tcp" />
</joint>
<link name="openarm_left_left_finger">
<visual name="openarm_left_left_finger_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_left_finger_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0064528 0.01702 0.0219685" />
<mass value="0.03602545343277134" />
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
</inertial>
</link>
<link name="openarm_left_right_finger">
<visual name="openarm_left_right_finger_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_left_right_finger_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0064528 -0.01702 0.0219685" />
<mass value="0.03602545343277134" />
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
</inertial>
</link>
<joint name="openarm_left_finger_joint1" type="prismatic">
<parent link="openarm_left_hand" />
<child link="openarm_left_right_finger" />
<origin rpy="0 0 0" xyz="0 -0.006 0.015" />
<axis xyz="0 -1 0" />
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
</joint>
<joint name="openarm_left_finger_joint2" type="prismatic">
<parent link="openarm_left_hand" />
<child link="openarm_left_left_finger" />
<origin rpy="0 0 0" xyz="0 0.006 0.015" />
<axis xyz="0 1 0" />
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
<mimic joint="openarm_left_finger_joint1" />
</joint>
<link name="openarm_right_hand">
<visual name="openarm_right_hand_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/hand.dae" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_hand_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/hand.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0 0.002 0.03" />
<mass value="0.35" />
<inertia ixx="0.0002473" ixy="1e-06" ixz="1e-06" iyy="1.763e-05" iyz="1e-06" izz="0.0002521" />
</inertial>
</link>
<joint name="right_openarm_hand_joint" type="fixed">
<parent link="openarm_right_link7" />
<child link="openarm_right_hand" />
<origin rpy="0 0 0" xyz="0 -0.0 0.1001" />
</joint>
<link name="openarm_right_hand_tcp">
<inertial>
<origin xyz="0 0 0" rpy="0 0 0" />
<mass value="0.001" />
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
</inertial>
</link>
<joint name="openarm_right_hand_tcp_joint" type="fixed">
<origin rpy="0 0 0" xyz="0 -0.0 0.08" />
<parent link="openarm_right_hand" />
<child link="openarm_right_hand_tcp" />
</joint>
<link name="openarm_right_left_finger">
<visual name="openarm_right_left_finger_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_left_finger_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0064528 0.01702 0.0219685" />
<mass value="0.03602545343277134" />
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
</inertial>
</link>
<link name="openarm_right_right_finger">
<visual name="openarm_right_right_finger_visual">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 -0.001 0.001" />
</geometry>
</visual>
<collision name="openarm_right_right_finger_collision">
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
<geometry>
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 -0.001 0.001" />
</geometry>
</collision>
<inertial>
<origin rpy="0 0 0" xyz="0.0064528 -0.01702 0.0219685" />
<mass value="0.03602545343277134" />
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
</inertial>
</link>
<joint name="openarm_right_finger_joint1" type="prismatic">
<parent link="openarm_right_hand" />
<child link="openarm_right_right_finger" />
<origin rpy="0 0 0" xyz="0 -0.006 0.015" />
<axis xyz="0 -1 0" />
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
</joint>
<joint name="openarm_right_finger_joint2" type="prismatic">
<parent link="openarm_right_hand" />
<child link="openarm_right_left_finger" />
<origin rpy="0 0 0" xyz="0 0.006 0.015" />
<axis xyz="0 1 0" />
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
<mimic joint="openarm_right_finger_joint1" />
</joint>
</robot>
@@ -1,395 +0,0 @@
"""
OpenArms Dataset Recording with Gravity + Friction Compensation
Records a dataset using OpenArms follower robot with leader teleoperator.
Leader arms have gravity and friction compensation for weightless, easy movement.
Includes 3 cameras: left wrist, right wrist, and base camera.
Uses the same compensation approach as teleop_with_compensation.py
"""
import shutil
import time
from pathlib import Path
import numpy as np
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
# Recording parameters
NUM_EPISODES = 1
FPS = 30
EPISODE_TIME_SEC = 600
RESET_TIME_SEC = 120
TASK_DESCRIPTION = "OpenArms task description"
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
FRICTION_SCALE = 1.0
def record_loop_with_compensation(
robot,
leader,
events,
fps,
dataset,
dataset_features,
control_time_s,
single_task,
display_data=True,
):
"""
Custom record loop that applies gravity + friction compensation to leader.
Based on record_loop but with integrated compensation.
"""
dt = 1 / fps
episode_start_time = time.perf_counter()
# All joints (both arms)
all_joints = []
for motor in leader.bus_right.motors:
all_joints.append(f"right_{motor}")
for motor in leader.bus_left.motors:
all_joints.append(f"left_{motor}")
while True:
loop_start = time.perf_counter()
elapsed = loop_start - episode_start_time
# Check if we should exit
if elapsed >= control_time_s or events["exit_early"] or events["stop_recording"]:
break
# Get leader state
leader_action = leader.get_action()
# Extract positions and velocities in degrees
leader_positions_deg = {}
leader_velocities_deg_per_sec = {}
for motor in leader.bus_right.motors:
pos_key = f"right_{motor}.pos"
vel_key = f"right_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
for motor in leader.bus_left.motors:
pos_key = f"left_{motor}.pos"
vel_key = f"left_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
# Calculate gravity torques for leader using built-in method
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
# Calculate friction torques for leader using built-in method
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
leader_friction_torques_nm = leader._friction_from_velocity(
leader_velocities_rad_per_sec,
friction_scale=FRICTION_SCALE
)
# Combine gravity + friction torques
leader_total_torques_nm = {}
for motor_name in leader_gravity_torques_nm:
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
friction = leader_friction_torques_nm.get(motor_name, 0.0)
leader_total_torques_nm[motor_name] = gravity + friction
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
for motor in leader.bus_right.motors:
full_name = f"right_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_right._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
for motor in leader.bus_left.motors:
full_name = f"left_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_left._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Send leader positions to follower (both arms)
follower_action = {}
for joint in all_joints:
pos_key = f"{joint}.pos"
if pos_key in leader_action:
follower_action[pos_key] = leader_action[pos_key]
# Send action to robot
if follower_action:
robot.send_action(follower_action)
# Get observation from robot (includes camera images)
observation = robot.get_observation()
# Add to dataset if we have a dataset
if dataset is not None:
# Build properly formatted observation frame
obs_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
# Build properly formatted action frame (keep .pos suffix - it matches the feature names)
action_frame = build_dataset_frame(dataset_features, follower_action, prefix="action")
# Combine into single frame
frame = {**obs_frame, **action_frame}
# Add metadata (task is required, timestamp will be auto-calculated by add_frame)
frame["task"] = single_task
dataset.add_frame(frame)
# Display data if requested
if display_data:
log_rerun_data(observation=observation, action=follower_action)
# Maintain loop rate
loop_duration = time.perf_counter() - loop_start
sleep_time = dt - loop_duration
if sleep_time > 0:
time.sleep(sleep_time)
def main():
"""Main recording loop with gravity compensation."""
print("=" * 70)
print("OpenArms Dataset Recording with Compensation")
print("=" * 70)
# Create camera configurations (3 cameras: left wrist, right wrist, base)
# Using actual device paths found by lerobot-find-cameras opencv
camera_config = {
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video0", width=640, height=480, fps=FPS),
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=FPS),
"base": OpenCVCameraConfig(index_or_path="/dev/video7", width=640, height=480, fps=FPS),
}
# Configure follower robot with cameras
follower_config = OpenArmsFollowerConfig(
port_left="can2",
port_right="can3",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
cameras=camera_config,
)
# Configure leader teleoperator (no cameras needed)
leader_config = OpenArmsLeaderConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
# Initialize robot and teleoperator
print("\nInitializing devices...")
follower = OpenArmsFollower(follower_config)
leader = OpenArmsLeader(leader_config)
# Connect devices
print("Connecting and calibrating...")
follower.connect(calibrate=True)
leader.connect(calibrate=True)
# Verify URDF is loaded for gravity compensation
if leader.pin_robot is None:
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
# Configure the dataset features
# For actions, we only want to record positions (not velocity or torque)
action_features_hw = {}
for key, value in follower.action_features.items():
if key.endswith(".pos"):
action_features_hw[key] = value
action_features = hw_to_dataset_features(action_features_hw, "action")
obs_features = hw_to_dataset_features(follower.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
print("\nCreating dataset...")
repo_id = "<hf_username>/<dataset_repo_id>" # TODO: Replace with your Hugging Face repo
# Check if dataset already exists and prompt user
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
while dataset_path.exists():
print(f"\nDataset already exists at: {dataset_path}")
print("\nOptions:")
print(" 1. Overwrite existing dataset")
print(" 2. Use a different name")
print(" 3. Abort")
choice = input("\nEnter your choice (1/2/3): ").strip()
if choice == '1':
print(f"Removing existing dataset...")
shutil.rmtree(dataset_path)
print("✓ Existing dataset removed")
break
elif choice == '2':
print("\nCurrent repo_id:", repo_id)
new_repo_id = input("Enter new repo_id (format: <username>/<dataset_name>): ").strip()
if new_repo_id and '/' in new_repo_id:
repo_id = new_repo_id
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id
print(f"✓ Using new repo_id: {repo_id}")
# Loop will continue if this new path also exists
else:
print("Invalid repo_id format. Please use format: <username>/<dataset_name>")
elif choice == '3':
print("Aborting. Please remove the existing dataset manually or restart with a different repo_id.")
follower.disconnect()
leader.disconnect()
return
else:
print("Invalid choice. Please enter 1, 2, or 3.")
dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=FPS,
features=dataset_features,
robot_type=follower.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize keyboard listener and visualization
_, events = init_keyboard_listener()
init_rerun(session_name="openarms_recording")
# Enable motors on both leader arms for gravity compensation
leader.bus_right.enable_torque()
leader.bus_left.enable_torque()
time.sleep(0.1)
print("\n" + "=" * 70)
print(f"Recording {NUM_EPISODES} episodes")
print(f"Task: {TASK_DESCRIPTION}")
print("=" * 70)
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
print("\nKeyboard controls:")
print(" - Press 'q' to stop recording")
print(" - Press 'r' to re-record current episode")
print("=" * 70)
episode_idx = 0
try:
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Record episode with compensation active
record_loop_with_compensation(
robot=follower,
leader=leader,
events=events,
fps=FPS,
dataset=dataset,
dataset_features=dataset_features,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop_with_compensation(
robot=follower,
leader=leader,
events=events,
fps=FPS,
dataset=None, # Don't save reset period
dataset_features=dataset_features,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
# Handle re-recording
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Only save episode if frames were recorded
if dataset.episode_buffer is not None and dataset.episode_buffer["size"] > 0:
dataset.save_episode()
episode_idx += 1
else:
log_say("No frames recorded, skipping episode save")
# Clear the empty buffer
dataset.episode_buffer = None
except KeyboardInterrupt:
print("\n\nStopping recording...")
finally:
# Clean up
log_say("Stop recording")
try:
leader.bus_right.disable_torque()
leader.bus_left.disable_torque()
time.sleep(0.1)
leader.disconnect()
follower.disconnect()
print("✓ Shutdown complete")
except Exception as e:
print(f"Shutdown error: {e}")
# Upload dataset
print("\nUploading dataset to Hugging Face Hub...")
try:
dataset.push_to_hub()
print("✓ Dataset uploaded successfully")
except Exception as e:
print(f"Warning: Failed to upload dataset: {e}")
print("You can manually upload later using: dataset.push_to_hub()")
print("✓ Recording complete!")
if __name__ == "__main__":
main()
-166
View File
@@ -1,166 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenArms Dataset Replay Example
Replays position actions from a recorded dataset on an OpenArms follower robot.
Only position commands (ending with .pos) are replayed, not velocity or torque.
Example usage:
python examples/openarms/replay.py
"""
import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
# Configuration
EPISODE_IDX = 0
DATASET_REPO_ID = "lerobot-data-collection/replay-this-2025-11-02-17-58" # TODO: Replace with your dataset
DATASET_ROOT = None # Use default cache location, or specify custom path
# Robot configuration - adjust these to match your setup
ROBOT_CONFIG = OpenArmsFollowerConfig(
port_left="can2", # CAN interface for left arm
port_right="can3", # CAN interface for right arm
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0, # Safety limit: max degrees to move per step
)
def main():
"""Main replay function."""
print("=" * 70)
print("OpenArms Dataset Replay")
print("=" * 70)
print(f"\nDataset: {DATASET_REPO_ID}")
print(f"Episode: {EPISODE_IDX}")
print(f"Robot: {ROBOT_CONFIG.id}")
print(f" Left arm: {ROBOT_CONFIG.port_left}")
print(f" Right arm: {ROBOT_CONFIG.port_right}")
print("\n" + "=" * 70)
# Initialize the robot
print("\n[1/3] Initializing robot...")
robot = OpenArmsFollower(ROBOT_CONFIG)
# Load the dataset
print(f"\n[2/3] Loading dataset '{DATASET_REPO_ID}'...")
dataset = LeRobotDataset(
DATASET_REPO_ID,
root=DATASET_ROOT,
episodes=[EPISODE_IDX]
)
# Filter dataset to only include frames from the specified episode
# (required for dataset V3.0 where episodes are chunked)
episode_frames = dataset.hf_dataset.filter(
lambda x: x["episode_index"] == EPISODE_IDX
)
if len(episode_frames) == 0:
raise ValueError(
f"No frames found for episode {EPISODE_IDX} in dataset {DATASET_REPO_ID}"
)
print(f" Found {len(episode_frames)} frames in episode {EPISODE_IDX}")
# Extract action features from dataset
action_features = dataset.features.get(ACTION, {})
action_names = action_features.get("names", [])
# Filter to only position actions (ending with .pos)
position_action_names = [name for name in action_names if name.endswith(".pos")]
if not position_action_names:
raise ValueError(
f"No position actions found in dataset. Action names: {action_names}"
)
print(f" Found {len(position_action_names)} position actions to replay")
print(f" Actions: {', '.join(position_action_names[:5])}{'...' if len(position_action_names) > 5 else ''}")
# Select only action columns from dataset
actions = episode_frames.select_columns(ACTION)
# Connect to the robot
print(f"\n[3/3] Connecting to robot...")
robot.connect(calibrate=False) # Skip calibration for replay
if not robot.is_connected:
raise RuntimeError("Robot failed to connect!")
print("\n" + "=" * 70)
print("Ready to replay!")
print("=" * 70)
print("\nThe robot will replay the recorded positions.")
print("Press Ctrl+C to stop at any time.\n")
input("Press ENTER to start replaying...")
# Replay loop
log_say(f"Replaying episode {EPISODE_IDX}", blocking=True)
try:
for idx in range(len(episode_frames)):
loop_start = time.perf_counter()
# Extract action array from dataset
action_array = actions[idx][ACTION]
# Build action dictionary, but only include position actions
action = {}
for i, name in enumerate(action_names):
# Only include position actions (ending with .pos)
if name.endswith(".pos"):
action[name] = float(action_array[i])
# Send action to robot
robot.send_action(action)
# Maintain replay rate (use dataset fps)
loop_duration = time.perf_counter() - loop_start
dt_s = 1.0 / dataset.fps - loop_duration
busy_wait(dt_s)
# Progress indicator every 100 frames
if (idx + 1) % 100 == 0:
progress = (idx + 1) / len(episode_frames) * 100
print(f"Progress: {idx + 1}/{len(episode_frames)} frames ({progress:.1f}%)")
print(f"\n✓ Successfully replayed {len(episode_frames)} frames")
log_say("Replay complete", blocking=True)
except KeyboardInterrupt:
print("\n\nReplay interrupted by user")
finally:
# Disconnect robot
print("\nDisconnecting robot...")
robot.disconnect()
print("✓ Replay complete!")
if __name__ == "__main__":
main()
-73
View File
@@ -1,73 +0,0 @@
#!/bin/bash
# Setup all OpenArms CAN interfaces with CAN FD
set -e
echo "=========================================="
echo "OpenArms CAN FD Interface Setup"
echo "=========================================="
echo ""
echo "Mode: CAN FD"
echo " - Nominal bitrate: 1 Mbps"
echo " - Data bitrate: 5 Mbps"
echo ""
echo "Configuring interfaces can0, can1, can2, can3..."
echo ""
# Configure each CAN interface with CAN FD
for i in 0 1 2 3; do
interface="can$i"
# Check if interface exists
if ! ip link show "$interface" &> /dev/null; then
echo "$interface: Not found, skipping"
continue
fi
# Bring down interface
sudo ip link set "$interface" down 2>/dev/null
# Configure CAN FD mode
sudo ip link set "$interface" type can \
bitrate 1000000 \
dbitrate 5000000 \
fd on
# Bring up interface
sudo ip link set "$interface" up
# Verify configuration
if ip link show "$interface" | grep -q "UP"; then
echo "$interface: Configured and UP"
else
echo "$interface: Failed to bring UP"
fi
done
echo ""
echo "=========================================="
echo "Verification"
echo "=========================================="
echo ""
# Show detailed status for each interface
for i in 0 1 2 3; do
interface="can$i"
if ip link show "$interface" &> /dev/null; then
echo "$interface:"
# Show key parameters
ip -d link show "$interface" | grep -E "can|state|bitrate|dbitrate" | head -3
echo ""
fi
done
echo "=========================================="
echo "Setup Complete!"
echo "=========================================="
echo ""
echo "All interfaces configured for CAN FD mode"
echo ""
echo "Next steps:"
echo " 1. Test motors: python debug_can_communication.py"
echo " 2. Run teleoperation: python examples/openarms/teleop.py"
echo ""
-148
View File
@@ -1,148 +0,0 @@
"""
OpenArms Teleoperation Example - Full Dual Arms
This script demonstrates teleoperation of OpenArms follower robot using an OpenArms leader arm.
It first calibrates both devices, then enters a teleoperation loop for both arms.
"""
import time
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
follower_config = OpenArmsFollowerConfig(
port_left="can2", # CAN interface for follower left arm
port_right="can3", # CAN interface for follower right arm
can_interface="socketcan", # Linux SocketCAN
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=5.0, # Safety limit
)
leader_config = OpenArmsLeaderConfig(
port_left="can0", # CAN interface for leader left arm
port_right="can1", # CAN interface for leader right arm
can_interface="socketcan", # Linux SocketCAN
id="openarms_leader",
manual_control=True, # Enable manual control (torque disabled)
)
print("=" * 60)
print("OpenArms Teleoperation - Full Dual Arms")
print("=" * 60)
# Initialize devices
print("\n[1/4] Initializing devices...")
follower = OpenArmsFollower(follower_config)
leader = OpenArmsLeader(leader_config)
# Connect and calibrate follower
print("\n[2/4] Connecting and calibrating follower robot...")
print("Note: If you have existing calibration, just press ENTER to use it.")
follower.connect(calibrate=True)
# Connect and calibrate leader
print("\n[3/4] Connecting and calibrating leader arm...")
print("Note: The leader arm will have torque disabled for manual control.")
leader.connect(calibrate=True)
# Wait for user to be ready
print("\n[4/4] Ready for teleoperation!")
print("\nBoth arms will be controlled (16 motors total):")
print(" RIGHT ARM: joints 1-7 + gripper")
print(" LEFT ARM: joints 1-7 + gripper")
print("\nPress ENTER to start teleoperation...")
input()
print("\nTeleoperation started! Move both leader arms.")
print("Press Ctrl+C to stop.\n")
# All joints for both arms (16 motors total)
all_joints = [
# Right arm
"right_joint_1",
"right_joint_2",
"right_joint_3",
"right_joint_4",
"right_joint_5",
"right_joint_6",
"right_joint_7",
"right_gripper",
# Left arm
"left_joint_1",
"left_joint_2",
"left_joint_3",
"left_joint_4",
"left_joint_5",
"left_joint_6",
"left_joint_7",
"left_gripper",
]
# Performance monitoring
loop_times = []
start_time = time.perf_counter()
last_print_time = start_time
try:
while True:
loop_start = time.perf_counter()
# Get action from leader
leader_action = leader.get_action()
# Filter to only position data for all joints (both arms)
joint_action = {}
for joint in all_joints:
pos_key = f"{joint}.pos"
if pos_key in leader_action:
joint_action[pos_key] = leader_action[pos_key]
# Send action to follower (both arms)
if joint_action:
follower.send_action(joint_action)
# Measure loop time
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Print stats every 2 seconds
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
min_time = min(loop_times)
max_time = max(loop_times)
max_hz = 1.0 / min_time if min_time > 0 else 0
min_hz = 1.0 / max_time if max_time > 0 else 0
print(f"[Hz Stats] Avg: {current_hz:.1f} Hz | "
f"Range: {min_hz:.1f}-{max_hz:.1f} Hz | "
f"Avg loop time: {avg_time*1000:.1f} ms")
# Reset for next measurement window
loop_times = []
last_print_time = loop_end
except KeyboardInterrupt:
print("\n\nStopping teleoperation...")
finally:
# Disconnect devices
print("Disconnecting devices...")
try:
follower.disconnect()
except Exception as e:
print(f"Error disconnecting follower: {e}")
try:
leader.disconnect()
except Exception as e:
print(f"Error disconnecting leader: {e}")
print("Done!")
-197
View File
@@ -1,197 +0,0 @@
"""
OpenArms Mini Teleoperation Example
This script demonstrates teleoperation of an OpenArms follower robot using
an OpenArms Mini leader (Feetech-based) with dual arms (16 motors total).
The OpenArms Mini has:
- Right arm: 8 motors (joint_1 to joint_7 + gripper)
- Left arm: 8 motors (joint_1 to joint_7 + gripper)
Note on gripper normalization:
- OpenArms Mini gripper: 0-100 scale (0=closed, 100=open)
- OpenArms follower gripper: degrees (0=closed, -65=open)
- This script automatically converts between the two ranges
"""
import time
import os
import sys
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.teleoperators.openarms_mini.openarms_mini import OpenArmsMini
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
from lerobot.utils.robot_utils import busy_wait
# Target control frequency
TARGET_FPS = 30
# Configure the OpenArms follower (Damiao motors on CAN bus)
follower_config = OpenArmsFollowerConfig(
port_left="can0", # CAN interface for follower left arm
port_right="can1", # CAN interface for follower right arm
can_interface="socketcan", # Linux SocketCAN
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0, # Safety limit (degrees per step)
)
# Configure the OpenArms Mini leader (Feetech motors on serial)
leader_config = OpenArmsMiniConfig(
port_right="/dev/ttyACM0", # Serial port for right arm
port_left="/dev/ttyACM1", # Serial port for left arm
id="openarms_mini",
use_degrees=True,
)
print("OpenArms Mini → OpenArms Follower Teleoperation")
# Initialize devices
follower = OpenArmsFollower(follower_config)
leader = OpenArmsMini(leader_config)
# Connect and calibrate follower
print("Note: If you have existing calibration, just press ENTER to use it.")
follower.connect(calibrate=True)
# Connect and calibrate leader
print("Note: The leader arms will have torque disabled for manual control.")
leader.connect(calibrate=True)
print("\nPress ENTER to start teleoperation...")
input()
print("Press Ctrl+C to stop.\n")
# All joints for both arms (16 motors total)
all_joints = [
# Right arm
"right_joint_1",
"right_joint_2",
"right_joint_3",
"right_joint_4",
"right_joint_5",
"right_joint_6",
"right_joint_7",
"right_gripper",
# Left arm
"left_joint_1",
"left_joint_2",
"left_joint_3",
"left_joint_4",
"left_joint_5",
"left_joint_6",
"left_joint_7",
"left_gripper",
]
# Performance monitoring
loop_times = []
avg_loop_time = 0.0
min_loop_time = float('inf')
max_loop_time = 0.0
stats_update_interval = 1.0 # Update stats every 1 second
last_stats_update = time.perf_counter()
SWAPPED_JOINTS = {
"right_joint_6": "right_joint_7",
"right_joint_7": "right_joint_6",
"left_joint_6": "left_joint_7",
"left_joint_7": "left_joint_6",
}
try:
while True:
loop_start = time.perf_counter()
# Get actions and observations
leader_action = leader.get_action()
follower_obs = follower.get_observation()
joint_action = {}
for joint in all_joints:
leader_key = f"{joint}.pos"
# Determine which follower joint this leader joint controls
follower_joint = SWAPPED_JOINTS.get(joint, joint)
follower_key = f"{follower_joint}.pos"
# Get leader position (default 0 if missing)
pos = leader_action.get(leader_key, 0.0)
# Convert gripper values: Mini uses 0-100, OpenArms uses 0 to -65 degrees
if "gripper" in joint:
# Map 0-100 (Mini) to 0 to -65 (OpenArms)
# 0 (closed) -> 0°, 100 (open) -> -65°
pos = (pos / 100.0) * -65.0
# Store in action dict for follower
joint_action[follower_key] = pos
follower.send_action(joint_action)
# Loop timing
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
# Update stats periodically
current_time = time.perf_counter()
if current_time - last_stats_update >= stats_update_interval:
if loop_times:
avg_loop_time = sum(loop_times) / len(loop_times)
min_loop_time = min(loop_times)
max_loop_time = max(loop_times)
loop_times = []
last_stats_update = current_time
# Display everything
sys.stdout.write("\033[H\033[J") # Clear screen
# Show timing stats at the top
if avg_loop_time > 0:
avg_hz = 1.0 / avg_loop_time
min_hz = 1.0 / max_loop_time if max_loop_time > 0 else 0
max_hz = 1.0 / min_loop_time if min_loop_time > 0 and min_loop_time < float('inf') else 0
print(f"[Performance] Target: {TARGET_FPS} Hz | Avg: {avg_hz:.1f} Hz | Range: {min_hz:.1f}-{max_hz:.1f} Hz | Loop: {avg_loop_time*1000:.1f} ms\n")
else:
print(f"[Performance] Target: {TARGET_FPS} Hz | Measuring...\n")
# Show joint positions
print(f"{'Joint':<20} {'Leader':>15} {'Follower':>15}")
print(f"{'':20} {'(0-100/deg)':>15} {'(deg)':>15}")
print("-" * 52)
for joint in all_joints:
leader_key = f"{joint}.pos"
follower_joint = SWAPPED_JOINTS.get(joint, joint)
follower_key = f"{follower_joint}.pos"
leader_pos = leader_action.get(leader_key, 0.0)
follower_pos = follower_obs.get(follower_key, 0.0)
print(f"{joint:<20} {leader_pos:>15.2f} {follower_pos:>15.2f}")
# Smart sleep to maintain target FPS
dt_s = time.perf_counter() - loop_start
busy_wait(max(0, 1.0 / TARGET_FPS - dt_s))
except KeyboardInterrupt:
print("\n\nStopping teleoperation...")
finally:
# Disconnect devices
print("Disconnecting devices...")
try:
follower.disconnect()
except Exception as e:
print(f"Error disconnecting follower: {e}")
try:
leader.disconnect()
except Exception as e:
print(f"Error disconnecting leader: {e}")
print("Done!")
@@ -1,202 +0,0 @@
"""
OpenArms Teleoperation with Gravity + Friction Compensation
Leader arms (both LEFT and RIGHT): Gravity + Friction compensation (weightless, easy to move)
Follower arms (both LEFT and RIGHT): Mirror leader movements
Uses the URDF file from the lerobot repository.
"""
import time
import numpy as np
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
# Friction compensation scale factor (1.0 = full, 0.3 = 30% for stability)
FRICTION_SCALE = 1.0
def main():
"""Main teleoperation loop with gravity compensation"""
print("=" * 70)
print("OpenArms Teleoperation with Gravity Compensation")
print("=" * 70)
# Configuration
follower_config = OpenArmsFollowerConfig(
port_left="can2",
port_right="can3",
can_interface="socketcan",
id="openarms_follower",
disable_torque_on_disconnect=True,
max_relative_target=10.0,
)
leader_config = OpenArmsLeaderConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_leader",
manual_control=False, # Enable torque control for gravity compensation
)
# Initialize and connect
print("\nInitializing devices...")
follower = OpenArmsFollower(follower_config)
leader = OpenArmsLeader(leader_config)
follower.connect()
leader.connect()
# URDF is automatically loaded in the leader constructor
if leader.pin_robot is None:
raise RuntimeError("URDF model not loaded on leader. Gravity compensation not available.")
print("\nLeader BOTH arms: Gravity + Friction comp | Follower BOTH arms: Teleop")
print("Press ENTER to start...")
input()
# Enable motors on both leader arms for gravity compensation
leader.bus_right.enable_torque()
leader.bus_left.enable_torque()
time.sleep(0.1)
print("Press Ctrl+C to stop\n")
# Main control loop
loop_times = []
last_print_time = time.perf_counter()
# All joints (both arms)
all_joints = []
for motor in leader.bus_right.motors:
all_joints.append(f"right_{motor}")
for motor in leader.bus_left.motors:
all_joints.append(f"left_{motor}")
try:
while True:
loop_start = time.perf_counter()
# Get leader state
leader_action = leader.get_action()
# Extract positions and velocities in degrees
leader_positions_deg = {}
leader_velocities_deg_per_sec = {}
for motor in leader.bus_right.motors:
pos_key = f"right_{motor}.pos"
vel_key = f"right_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"right_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"right_{motor}"] = leader_action[vel_key]
for motor in leader.bus_left.motors:
pos_key = f"left_{motor}.pos"
vel_key = f"left_{motor}.vel"
if pos_key in leader_action:
leader_positions_deg[f"left_{motor}"] = leader_action[pos_key]
if vel_key in leader_action:
leader_velocities_deg_per_sec[f"left_{motor}"] = leader_action[vel_key]
# Calculate gravity torques for leader using built-in method
leader_positions_rad = {k: np.deg2rad(v) for k, v in leader_positions_deg.items()}
leader_gravity_torques_nm = leader._gravity_from_q(leader_positions_rad)
# Calculate friction torques for leader using built-in method
leader_velocities_rad_per_sec = {k: np.deg2rad(v) for k, v in leader_velocities_deg_per_sec.items()}
leader_friction_torques_nm = leader._friction_from_velocity(
leader_velocities_rad_per_sec,
friction_scale=FRICTION_SCALE
)
# Combine gravity + friction torques
leader_total_torques_nm = {}
for motor_name in leader_gravity_torques_nm:
gravity = leader_gravity_torques_nm.get(motor_name, 0.0)
friction = leader_friction_torques_nm.get(motor_name, 0.0)
leader_total_torques_nm[motor_name] = gravity + friction
# Apply gravity + friction compensation to leader RIGHT arm (all joints including gripper)
for motor in leader.bus_right.motors:
full_name = f"right_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_right._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Apply gravity + friction compensation to leader LEFT arm (all joints including gripper)
for motor in leader.bus_left.motors:
full_name = f"left_{motor}"
position = leader_positions_deg.get(full_name, 0.0)
torque = leader_total_torques_nm.get(full_name, 0.0)
# Get damping gain for stability
kd = leader.get_damping_kd(motor)
leader.bus_left._mit_control(
motor=motor,
kp=0.0,
kd=kd, # Add damping for stability
position_degrees=position,
velocity_deg_per_sec=0.0,
torque=torque,
)
# Send leader positions to follower (both arms)
follower_action = {}
for joint in all_joints:
pos_key = f"{joint}.pos"
if pos_key in leader_action:
follower_action[pos_key] = leader_action[pos_key]
if follower_action:
follower.send_action(follower_action)
# Performance monitoring
loop_end = time.perf_counter()
loop_time = loop_end - loop_start
loop_times.append(loop_time)
if loop_end - last_print_time >= 2.0:
if loop_times:
avg_time = sum(loop_times) / len(loop_times)
current_hz = 1.0 / avg_time if avg_time > 0 else 0
print(f"{current_hz:.1f} Hz ({avg_time*1000:.1f} ms)")
loop_times = []
last_print_time = loop_end
except KeyboardInterrupt:
print("\n\nStopping...")
finally:
try:
leader.bus_right.disable_torque()
leader.bus_left.disable_torque()
time.sleep(0.1)
leader.disconnect()
follower.disconnect()
print("✓ Shutdown complete")
except Exception as e:
print(f"Shutdown error: {e}")
if __name__ == "__main__":
main()
-76
View File
@@ -1,76 +0,0 @@
import time
import math
import numpy as np
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
def main():
cfg = OpenArmsFollowerConfig(
port_left="can0",
port_right="can1",
can_interface="socketcan",
id="openarms_test",
manual_control=True, # direct position control
)
print('connecting...')
rob = OpenArmsFollower(cfg)
rob.connect(calibrate=True)
# disable left torque fully — keep it still
rob.bus_left.disable_torque()
# desired angular sweep = 1/4 of current joint range
sweep_deg = 20.0 # tweak if you want bigger movement
# frequency of movement
hz = 100.0
dt = 1.0 / hz
move_time = 1.0 # seconds per joint
print('starting rightarm joint test…')
print('support the arm and keep clear')
time.sleep(1.0)
# iterate motors except gripper
for motor in rob.bus_right.motors:
if motor == 'gripper':
continue
print(f'testing {motor} on right arm...')
start = time.time()
# read current position as center
obs = rob.get_action()
key = f'right_{motor}.pos'
center = obs.get(key, 0.0)
t = 0.0
while time.time() - start < move_time:
offset = sweep_deg * math.sin(2 * math.pi * t)
pos_cmd = center + offset
rob.bus_right._mit_control(
motor=motor,
kp=3.0, # some stiffness so it tracks well
kd=0.2,
position_degrees=pos_cmd,
velocity_deg_per_sec=0.0,
torque=0.0
)
t += dt
time.sleep(dt)
print(f'done {motor}')
print('\nall rightarm joints tested')
print('disabling torque…')
rob.bus_right.disable_torque()
rob.disconnect()
if __name__ == '__main__':
main()
-745
View File
@@ -1,745 +0,0 @@
body {
margin: 0;
padding: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: #f5f5f5;
}
main {
min-height: 100vh;
padding: 2rem;
}
header {
text-align: center;
margin-bottom: 2rem;
}
h1 {
font-size: 2rem;
font-weight: 600;
color: #333;
margin: 0;
}
h2 {
font-size: 1.25rem;
font-weight: 600;
color: #333;
margin: 0 0 1rem 0;
}
h3 {
font-size: 0.875rem;
font-weight: 600;
color: #666;
margin: 0 0 0.5rem 0;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.container {
max-width: 1920px;
margin: 0 auto;
display: grid;
grid-template-columns: minmax(500px, 600px) 1fr;
gap: 2rem;
align-items: start;
}
/* Left column container */
.left-column {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
/* Right column container */
.right-column {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
/* Responsive: Stack on smaller screens */
@media (max-width: 1200px) {
.container {
grid-template-columns: 1fr;
}
}
.panel {
background: white;
border-radius: 8px;
padding: 1.5rem;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
.config-panel {
border: 2px solid #e5e7eb;
}
.config-header {
display: flex;
justify-content: space-between;
align-items: center;
cursor: pointer;
user-select: none;
padding: 0.5rem 0;
}
.config-header:hover {
opacity: 0.7;
}
.toggle-icon {
font-size: 1rem;
color: #6b7280;
transition: transform 0.2s;
}
.config-content {
margin-top: 1rem;
padding-top: 1rem;
border-top: 1px solid #e5e7eb;
}
.robot-setup {
margin-bottom: 0.5rem;
}
.robot-status {
display: flex;
align-items: center;
justify-content: space-between;
padding: 1rem;
border-radius: 6px;
font-weight: 500;
gap: 1rem;
}
.robot-status.ready {
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
color: #065f46;
border: 1px solid #10b981;
}
.robot-status.not-ready {
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
color: #92400e;
border: 1px solid #f59e0b;
}
.btn-setup {
background: #10b981;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-setup:hover:not(:disabled) {
background: #059669;
}
.btn-setup:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.btn-zero {
background: #8b5cf6;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-zero:hover:not(:disabled) {
background: #7c3aed;
}
.btn-zero:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.zero-position-section {
margin-top: 1rem;
padding-top: 1rem;
border-top: 1px solid #e5e7eb;
}
.btn-zero-large {
width: 100%;
background: #8b5cf6;
color: white;
border: none;
padding: 0.875rem 1.5rem;
border-radius: 8px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
box-shadow: 0 2px 4px rgba(139, 92, 246, 0.2);
}
.btn-zero-large:hover:not(:disabled) {
background: #7c3aed;
box-shadow: 0 4px 8px rgba(139, 92, 246, 0.3);
transform: translateY(-1px);
}
.btn-zero-large:disabled {
background: #d1d5db;
cursor: not-allowed;
box-shadow: none;
transform: none;
}
.delete-episode-section {
margin-top: 1rem;
padding-top: 1rem;
border-top: 1px solid #e5e7eb;
}
.btn-delete {
width: 100%;
background: #ef4444;
color: white;
border: none;
padding: 0.875rem 1.5rem;
border-radius: 8px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
box-shadow: 0 2px 4px rgba(239, 68, 68, 0.2);
}
.btn-delete:hover:not(:disabled) {
background: #dc2626;
box-shadow: 0 4px 8px rgba(239, 68, 68, 0.3);
transform: translateY(-1px);
}
.btn-delete:disabled {
background: #d1d5db;
cursor: not-allowed;
box-shadow: none;
transform: none;
}
.delete-info {
margin-top: 0.5rem;
font-size: 0.875rem;
color: #666;
text-align: center;
font-style: italic;
}
.btn-disconnect {
background: #ef4444;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 4px;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-disconnect:hover {
background: #dc2626;
}
.btn-refresh {
background: #3b82f6;
color: white;
border: none;
padding: 0.4rem 0.8rem;
border-radius: 4px;
font-size: 0.75rem;
font-weight: 500;
cursor: pointer;
transition: background 0.2s;
}
.btn-refresh:hover:not(:disabled) {
background: #2563eb;
}
.btn-refresh:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.control-panel {
border: 2px solid #10b981;
}
.status-banner {
display: flex;
align-items: center;
gap: 1rem;
padding: 1rem 1.5rem;
border-radius: 6px;
margin-bottom: 1.5rem;
font-weight: 500;
font-size: 0.95rem;
}
.status-banner.initializing {
background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%);
color: #1e40af;
border-left: 4px solid #3b82f6;
}
.status-banner.encoding {
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
color: #92400e;
border-left: 4px solid #f59e0b;
}
.status-banner.uploading {
background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
color: #3730a3;
border-left: 4px solid #6366f1;
}
.status-banner.success {
background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%);
color: #065f46;
border-left: 4px solid #10b981;
}
.status-banner.warning {
background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%);
color: #991b1b;
border-left: 4px solid #ef4444;
}
.spinner {
width: 20px;
height: 20px;
border: 3px solid rgba(0, 0, 0, 0.1);
border-top-color: currentColor;
border-radius: 50%;
animation: spin 0.8s linear infinite;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.control-horizontal {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.control-left {
display: flex;
flex-direction: column;
gap: 1rem;
}
.control-right {
display: flex;
align-items: center;
justify-content: center;
}
.input-group {
display: flex;
gap: 0.5rem;
margin-bottom: 0;
}
input[type="text"] {
flex: 1;
padding: 0.75rem;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 1rem;
}
input[type="text"]:disabled {
background: #f5f5f5;
cursor: not-allowed;
}
input[type="text"]:focus {
outline: none;
border-color: #10b981;
}
button {
padding: 0.75rem 1.5rem;
border: none;
border-radius: 4px;
font-size: 1rem;
font-weight: 500;
cursor: pointer;
transition: all 0.2s;
}
.btn-set-task {
background: #3b82f6;
color: white;
min-width: 120px;
}
.btn-set-task:hover:not(:disabled) {
background: #2563eb;
}
.btn-set-task:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.btn-start {
background: #10b981;
color: white;
}
.btn-start:hover:not(:disabled) {
background: #059669;
}
.btn-start:disabled {
background: #d1d5db;
cursor: not-allowed;
}
.btn-stop {
background: #ef4444;
color: white;
}
.btn-stop:hover {
background: #dc2626;
}
.btn-reset {
padding: 0.5rem 1rem;
background: #6b7280;
color: white;
font-size: 0.875rem;
}
.btn-reset:hover {
background: #4b5563;
}
.status {
display: flex;
align-items: center;
gap: 0.75rem;
padding: 1rem;
border-radius: 4px;
margin-bottom: 1rem;
}
.status.recording {
background: #fee2e2;
color: #991b1b;
}
.status.recording.recording-active {
display: flex;
flex-direction: column;
gap: 1rem;
background: #dc2626;
color: white;
padding: 1.5rem;
border: 4px solid #991b1b;
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.4);
font-weight: 700;
font-size: 1rem;
}
.status.recording.recording-active .indicator {
width: 20px;
height: 20px;
background: #fef2f2;
animation: pulse-strong 1s ease-in-out infinite;
}
@keyframes pulse-strong {
0%, 100% {
opacity: 1;
transform: scale(1);
}
50% {
opacity: 0.7;
transform: scale(1.1);
}
}
.status.recording.recording-active .time-display {
display: flex;
flex-direction: column;
gap: 0.5rem;
font-size: 1.5rem;
font-weight: 700;
color: white;
}
.fps-display {
font-size: 1rem;
font-weight: 500;
opacity: 0.95;
}
.fps-warning {
color: #fef2f2;
animation: pulse-warning 1s ease-in-out infinite;
}
@keyframes pulse-warning {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.status.recording.recording-active .btn-stop {
align-self: stretch;
}
.ramp-up-countdown {
display: flex;
justify-content: center;
margin-bottom: 1rem;
}
.countdown-box {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 2rem 3rem;
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
border: 4px solid #f59e0b;
border-radius: 16px;
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
min-width: 280px;
animation: pulse-warm 1.5s ease-in-out infinite;
}
@keyframes pulse-warm {
0%, 100% {
box-shadow: 0 6px 20px rgba(245, 158, 11, 0.4);
}
50% {
box-shadow: 0 6px 25px rgba(245, 158, 11, 0.6);
}
}
.countdown-label {
font-size: 1rem;
color: #92400e;
text-transform: uppercase;
letter-spacing: 1.5px;
font-weight: 800;
margin-bottom: 1rem;
text-align: center;
}
.countdown-value {
font-size: 4.5rem;
font-weight: 900;
color: #d97706;
font-family: 'Courier New', monospace;
line-height: 1;
text-shadow: 2px 2px 6px rgba(0, 0, 0, 0.15);
margin-bottom: 0.5rem;
}
.countdown-subtitle {
font-size: 0.875rem;
color: #78350f;
font-weight: 600;
font-style: italic;
text-align: center;
margin-top: 0.5rem;
}
.status.idle {
background: #f3f4f6;
color: #374151;
}
.indicator {
width: 12px;
height: 12px;
border-radius: 50%;
background: #ef4444;
animation: pulse 1.5s ease-in-out infinite;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
.counter {
display: flex;
flex-direction: column;
align-items: center;
gap: 0.75rem;
padding: 1.5rem;
background: linear-gradient(135deg, #f9fafb 0%, #f3f4f6 100%);
border-radius: 8px;
border: 2px solid #e5e7eb;
min-width: 200px;
}
.counter-label {
font-size: 0.75rem;
color: #6b7280;
text-transform: uppercase;
letter-spacing: 0.5px;
font-weight: 600;
}
.counter-value {
font-size: 3rem;
font-weight: 700;
color: #10b981;
line-height: 1;
}
.time-display {
font-size: 1.5rem;
font-weight: 600;
font-family: 'Courier New', monospace;
}
.error-box {
padding: 1rem;
background: #fee2e2;
color: #991b1b;
border-radius: 4px;
border-left: 4px solid #ef4444;
font-size: 0.875rem;
}
.config-section {
margin-bottom: 1.5rem;
}
.config-section:last-child {
margin-bottom: 0;
}
.config-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 1rem;
}
label {
display: flex;
flex-direction: column;
gap: 0.5rem;
font-size: 0.875rem;
color: #374151;
font-weight: 500;
}
select {
padding: 0.5rem;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 0.875rem;
background: white;
}
select:disabled {
background: #f5f5f5;
cursor: not-allowed;
}
/* Camera Layout */
.camera-layout {
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.camera-base {
width: 100%;
}
.camera-wrist-container {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 1.5rem;
}
.camera-wrist {
width: 100%;
}
.camera {
border: 1px solid #e5e7eb;
border-radius: 4px;
overflow: hidden;
}
.camera h3 {
padding: 0.75rem;
background: #f9fafb;
border-bottom: 1px solid #e5e7eb;
margin: 0;
}
.camera img {
width: 100%;
height: auto;
display: block;
background: #000;
min-height: 300px;
object-fit: cover;
}
.camera-placeholder {
text-align: center;
padding: 4rem 2rem;
background: #f9fafb;
border-radius: 4px;
border: 2px dashed #d1d5db;
}
.camera-placeholder p {
margin: 0.5rem 0;
font-size: 1rem;
color: #6b7280;
}
.camera-placeholder p:first-child {
font-size: 1.25rem;
font-weight: 500;
color: #374151;
}
.hint {
margin-top: 0.5rem;
font-size: 0.75rem;
color: #6b7280;
display: flex;
align-items: center;
gap: 0.5rem;
flex-wrap: wrap;
}
-857
View File
@@ -1,857 +0,0 @@
import { useState, useEffect, useCallback, useRef } from 'react';
import './App.css';
const API_BASE = 'http://localhost:8000/api';
function App() {
// State
const [task, setTask] = useState('');
const [isRecording, setIsRecording] = useState(false);
const [isInitializing, setIsInitializing] = useState(false);
const [isEncoding, setIsEncoding] = useState(false);
const [isUploading, setIsUploading] = useState(false);
const [robotsReady, setRobotsReady] = useState(false);
const [elapsedTime, setElapsedTime] = useState(0);
const [currentFps, setCurrentFps] = useState(0);
const [loopFps, setLoopFps] = useState(0);
const [episodeCount, setEpisodeCount] = useState(0);
const [error, setError] = useState(null);
const [statusMessage, setStatusMessage] = useState('Ready');
const [uploadStatus, setUploadStatus] = useState(null);
const [rampUpRemaining, setRampUpRemaining] = useState(0);
const [movingToZero, setMovingToZero] = useState(false);
const [configExpanded, setConfigExpanded] = useState(false);
const [latestRepoId, setLatestRepoId] = useState(null);
// Configuration
const [config, setConfig] = useState({
leader_type: 'openarms', // 'openarms' or 'openarms_mini'
leader_left: 'can0',
leader_right: 'can1',
follower_left: 'can2',
follower_right: 'can3',
left_wrist: '/dev/video0',
right_wrist: '/dev/video1',
base: '/dev/video4'
});
// Available options
const [availableCameras, setAvailableCameras] = useState([]);
const [availableUsbPorts, setAvailableUsbPorts] = useState([]);
const canInterfaces = ['can0', 'can1', 'can2', 'can3'];
const statusIntervalRef = useRef(null);
const hasInitializedRef = useRef(false);
const loadConfig = () => {
try {
const saved = localStorage.getItem('openarms_config');
if (saved) {
const loadedConfig = JSON.parse(saved);
setConfig(prev => ({ ...prev, ...loadedConfig }));
}
} catch (e) {
console.error('Load config error:', e);
}
};
const saveConfig = (newConfig) => {
try {
localStorage.setItem('openarms_config', JSON.stringify(newConfig || config));
} catch (e) {
console.error('Save config error:', e);
}
};
// Fetch status periodically
const fetchStatus = async () => {
try {
const response = await fetch(`${API_BASE}/status`);
const data = await response.json();
setIsRecording(data.is_recording);
setIsInitializing(data.is_initializing);
setIsEncoding(data.is_encoding);
setIsUploading(data.is_uploading);
setRobotsReady(data.robots_ready);
setElapsedTime(data.elapsed_time);
setCurrentFps(data.current_fps || 0);
setLoopFps(data.loop_fps || 0);
setEpisodeCount(data.episode_count);
setError(data.error);
setStatusMessage(data.status_message || 'Ready');
setUploadStatus(data.upload_status);
setRampUpRemaining(data.ramp_up_remaining || 0);
setMovingToZero(data.moving_to_zero || false);
// Track the latest repo_id from the backend
if (data.latest_repo_id) {
setLatestRepoId(data.latest_repo_id);
}
if (data.config) {
// Only merge server config if we don't have a saved config (first load)
if (!localStorage.getItem('openarms_config')) {
setConfig(prev => {
const merged = { ...data.config, ...prev };
localStorage.setItem('openarms_config', JSON.stringify(merged));
return merged;
});
}
}
} catch (e) {
console.error('Failed to fetch status:', e);
}
};
const setupRobots = async () => {
// Show warning to verify camera positions
const confirmed = window.confirm(
'⚠️ IMPORTANT: Before connecting robots, please verify:\n\n' +
'📹 Check that cameras are correctly positioned:\n' +
' • LEFT wrist camera is actually on the LEFT arm\n' +
' • RIGHT wrist camera is actually on the RIGHT arm\n' +
' • BASE camera is actually the BASE/overhead camera\n\n' +
'Incorrect camera positioning will result in invalid training data!\n\n' +
'Click OK to continue with robot setup, or Cancel to review configuration.'
);
if (!confirmed) {
return; // User cancelled, don't proceed
}
setError(null);
try {
const response = await fetch(`${API_BASE}/robots/setup`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(config)
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to setup robots');
}
await response.json();
saveConfig(config);
} catch (e) {
setError(`Robot setup failed: ${e.message}`);
}
};
// Disconnect robots
const disconnectRobots = async () => {
try {
await fetch(`${API_BASE}/robots/disconnect`, { method: 'POST' });
setRobotsReady(false);
} catch (e) {
console.error('Failed to disconnect robots:', e);
}
};
// Discover cameras
const discoverCameras = async () => {
try {
const response = await fetch(`${API_BASE}/cameras/discover`);
const data = await response.json();
const cameras = data.cameras || [];
setAvailableCameras(cameras);
// Get list of valid camera IDs
const validCameraIds = cameras.map(cam => String(cam.id));
// Auto-fix config if current values are invalid or not set
const updated = { ...config };
let changed = false;
// Auto-fix invalid camera config
if (!config.left_wrist || !validCameraIds.includes(config.left_wrist)) {
if (cameras.length >= 1) {
updated.left_wrist = String(cameras[0].id);
changed = true;
}
}
if (!config.right_wrist || !validCameraIds.includes(config.right_wrist)) {
if (cameras.length >= 2) {
updated.right_wrist = String(cameras[1].id);
changed = true;
}
}
if (!config.base || !validCameraIds.includes(config.base)) {
if (cameras.length >= 3) {
updated.base = String(cameras[2].id);
changed = true;
}
}
if (changed) {
setConfig(updated);
saveConfig(updated);
}
if (cameras.length === 0) {
setError('No cameras detected! Please connect cameras and refresh.');
}
} catch (e) {
console.error('Failed to discover cameras:', e);
setError(`Camera discovery failed: ${e.message}`);
}
};
// Discover USB ports
const discoverUsbPorts = async () => {
try {
const response = await fetch(`${API_BASE}/usb/discover`);
const data = await response.json();
const ports = data.ports || [];
setAvailableUsbPorts(ports);
// Auto-fix config if OpenArms Mini is selected and ports are invalid
if (config.leader_type === 'openarms_mini') {
const updated = { ...config };
let changed = false;
if (ports.length >= 1 && !ports.includes(config.leader_left)) {
updated.leader_left = ports[0];
changed = true;
}
if (ports.length >= 2 && !ports.includes(config.leader_right)) {
updated.leader_right = ports[1];
changed = true;
}
if (changed) {
setConfig(updated);
saveConfig(updated);
}
}
if (ports.length === 0) {
console.warn('No USB ports detected for OpenArms Mini');
}
} catch (e) {
console.error('Failed to discover USB ports:', e);
}
};
// Set task only (for pedal use)
const setTaskOnly = async () => {
if (!task.trim()) {
setError('Please enter a task description');
return;
}
setError(null);
try {
const response = await fetch(`${API_BASE}/recording/set-task`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ task, ...config })
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to set task');
}
const result = await response.json();
setStatusMessage(result.message || `Task set: ${task}`);
saveConfig(config);
// Clear success message after 3 seconds
setTimeout(() => {
if (!isRecording && !isInitializing) {
setStatusMessage('Ready');
}
}, 3000);
} catch (e) {
setError(e.message);
}
};
// Start recording
const startRecording = async () => {
if (!task.trim()) {
setError('Please enter a task description');
return;
}
setError(null);
try {
const response = await fetch(`${API_BASE}/recording/start`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ task, ...config })
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to start recording');
}
await response.json();
saveConfig(config);
} catch (e) {
setError(e.message);
}
};
// Stop recording
const stopRecording = async () => {
try {
const response = await fetch(`${API_BASE}/recording/stop`, {
method: 'POST'
});
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to stop recording');
}
const data = await response.json();
setError(null);
// Update latest repo_id after recording
if (data.dataset_name) {
setLatestRepoId(`lerobot-data-collection/${data.dataset_name}`);
}
} catch (e) {
setError(e.message);
}
};
const deleteLatestEpisode = async () => {
if (!latestRepoId) {
setError('No episode to delete');
return;
}
const confirmed = window.confirm(
`WARNING: This will permanently delete the repository:\n\n${latestRepoId}\n\nThis action cannot be undone. Continue?`
);
if (!confirmed) {
return;
}
try {
const response = await fetch(`${API_BASE}/recording/delete-latest`, { method: 'POST' });
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to delete episode');
}
const data = await response.json();
setLatestRepoId(null);
setEpisodeCount(Math.max(0, episodeCount - 1));
setStatusMessage(`Deleted: ${data.deleted_repo}`);
setTimeout(() => {
if (!isRecording && !isInitializing) {
setStatusMessage('Ready');
}
}, 3000);
} catch (e) {
setError(`Delete failed: ${e.message}`);
}
};
// Reset counter
const resetCounter = async () => {
try {
await fetch(`${API_BASE}/counter/reset`, { method: 'POST' });
setEpisodeCount(0);
} catch (e) {
console.error('Failed to reset counter:', e);
}
};
// Move robot to zero position
const moveToZero = async () => {
setError(null);
try {
const response = await fetch(`${API_BASE}/robots/move-to-zero`, { method: 'POST' });
if (!response.ok) {
const data = await response.json();
throw new Error(data.detail || 'Failed to move to zero position');
}
await response.json();
} catch (e) {
setError(`Move to zero failed: ${e.message}`);
}
};
// Format time as MM:SS
const formatTime = (seconds) => {
const mins = Math.floor(seconds / 60);
const secs = Math.floor(seconds % 60);
return `${mins.toString().padStart(2, '0')}:${secs.toString().padStart(2, '0')}`;
};
// Update config and save
const updateConfig = (key, value) => {
const updated = { ...config, [key]: value };
setConfig(updated);
saveConfig(updated);
};
// Initialize on mount only
useEffect(() => {
// Prevent double-initialization in development
if (hasInitializedRef.current) {
return;
}
hasInitializedRef.current = true;
loadConfig();
discoverCameras();
discoverUsbPorts();
fetchStatus();
statusIntervalRef.current = setInterval(fetchStatus, 1000);
return () => {
if (statusIntervalRef.current) {
clearInterval(statusIntervalRef.current);
}
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []); // Run only once on mount
// Discover USB ports when leader type changes to Mini
useEffect(() => {
if (config.leader_type === 'openarms_mini') {
discoverUsbPorts();
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [config.leader_type]);
return (
<main>
<header>
<h1>OpenArms Recording</h1>
</header>
<div className="container">
{/* Left Column: Configuration and Recording Control */}
<div className="left-column">
{/* Configuration Panel */}
<section className="panel config-panel">
<div
className="config-header"
onClick={() => setConfigExpanded(!configExpanded)}
role="button"
tabIndex={0}
onKeyDown={(e) => e.key === 'Enter' && setConfigExpanded(!configExpanded)}
>
<h2> Configuration</h2>
<span className="toggle-icon">{configExpanded ? '▼' : '▶'}</span>
</div>
{configExpanded && (
<div className="config-content">
{/* Robot Setup */}
<div className="config-section">
<h3>🤖 Robot Setup</h3>
<div className="robot-setup">
{robotsReady ? (
<div className="robot-status ready">
<span> Robots Ready - Recording will start instantly</span>
<button onClick={disconnectRobots} className="btn-disconnect">
Disconnect Robots
</button>
</div>
) : (
<div className="robot-status not-ready">
<span> Robots not initialized - Recording will take ~10 seconds</span>
<button
onClick={setupRobots}
disabled={isRecording || isInitializing}
className="btn-setup"
>
🚀 Setup Robots
</button>
</div>
)}
</div>
</div>
{/* Leader Type Selection */}
<div className="config-section">
<h3>🎮 Leader Type</h3>
<div className="config-grid">
<label style={{gridColumn: '1 / -1'}}>
Leader Arm Type
<select
value={config.leader_type}
onChange={(e) => updateConfig('leader_type', e.target.value)}
disabled={isRecording || robotsReady}
>
<option value="openarms">OpenArms (CAN Bus - Damiao Motors)</option>
<option value="openarms_mini">OpenArms Mini (USB - Feetech Motors)</option>
</select>
</label>
</div>
</div>
{/* Leader Interfaces (CAN or USB based on type) */}
<div className="config-section">
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
<h3>
{config.leader_type === 'openarms_mini'
? `Leader Ports (USB/Serial) ${availableUsbPorts.length > 0 ? `(${availableUsbPorts.length} detected)` : ''}`
: 'Leader Interfaces (CAN)'}
</h3>
{config.leader_type === 'openarms_mini' && (
<button
onClick={discoverUsbPorts}
className="btn-refresh"
disabled={isRecording || robotsReady}
>
🔄 Refresh
</button>
)}
</div>
<div className="config-grid">
<label>
Leader Left
<select
value={config.leader_left}
onChange={(e) => updateConfig('leader_left', e.target.value)}
disabled={isRecording || robotsReady}
>
{config.leader_type === 'openarms_mini' ? (
availableUsbPorts.length > 0 ? (
availableUsbPorts.map((port) => (
<option key={port} value={port}>{port}</option>
))
) : (
<option value="">No USB ports detected</option>
)
) : (
canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))
)}
</select>
</label>
<label>
Leader Right
<select
value={config.leader_right}
onChange={(e) => updateConfig('leader_right', e.target.value)}
disabled={isRecording || robotsReady}
>
{config.leader_type === 'openarms_mini' ? (
availableUsbPorts.length > 0 ? (
availableUsbPorts.map((port) => (
<option key={port} value={port}>{port}</option>
))
) : (
<option value="">No USB ports detected</option>
)
) : (
canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))
)}
</select>
</label>
</div>
</div>
{/* Follower CAN Interfaces */}
<div className="config-section">
<h3>Follower Interfaces (CAN)</h3>
<div className="config-grid">
<label>
Follower Left
<select
value={config.follower_left}
onChange={(e) => updateConfig('follower_left', e.target.value)}
disabled={isRecording || robotsReady}
>
{canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))}
</select>
</label>
<label>
Follower Right
<select
value={config.follower_right}
onChange={(e) => updateConfig('follower_right', e.target.value)}
disabled={isRecording || robotsReady}
>
{canInterfaces.map((iface) => (
<option key={iface} value={iface}>{iface}</option>
))}
</select>
</label>
</div>
</div>
{/* Camera Configuration */}
<div className="config-section">
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: '0.5rem' }}>
<h3>Cameras {availableCameras.length > 0 && `(${availableCameras.length} detected)`}</h3>
<button
onClick={discoverCameras}
className="btn-refresh"
disabled={isRecording || robotsReady}
>
🔄 Refresh
</button>
</div>
<div className="config-grid">
<label>
Left Wrist
<select
value={config.left_wrist}
onChange={(e) => updateConfig('left_wrist', e.target.value)}
disabled={isRecording || robotsReady}
>
{availableCameras.map((cam) => (
<option key={cam.id} value={String(cam.id)}>
{cam.name || `Camera @ ${cam.id}`}
</option>
))}
</select>
</label>
<label>
Right Wrist
<select
value={config.right_wrist}
onChange={(e) => updateConfig('right_wrist', e.target.value)}
disabled={isRecording || robotsReady}
>
{availableCameras.map((cam) => (
<option key={cam.id} value={String(cam.id)}>
{cam.name || `Camera @ ${cam.id}`}
</option>
))}
</select>
</label>
<label>
Base Camera
<select
value={config.base}
onChange={(e) => updateConfig('base', e.target.value)}
disabled={isRecording || robotsReady}
>
{availableCameras.map((cam) => (
<option key={cam.id} value={String(cam.id)}>
{cam.name || `Camera @ ${cam.id}`}
</option>
))}
</select>
</label>
</div>
</div>
</div>
)}
</section>
{/* Control Panel */}
<section className="panel control-panel">
<h2>🎬 Recording Control</h2>
{/* Status Banner - Always show important statuses */}
{isInitializing && (
<div className="status-banner initializing">
<div className="spinner"></div>
<span>{statusMessage}</span>
</div>
)}
{isEncoding && (
<div className="status-banner encoding">
<div className="spinner"></div>
<span>📹 {statusMessage}</span>
</div>
)}
{isUploading && (
<div className="status-banner uploading">
<div className="spinner"></div>
<span> {statusMessage}</span>
</div>
)}
{uploadStatus && !isRecording && !isEncoding && !isUploading && (
<div className={`status-banner ${uploadStatus.startsWith('✓') ? 'success' : 'warning'}`}>
<span>{uploadStatus}</span>
</div>
)}
<div className="control-horizontal">
{/* Task Input and Status */}
<div className="control-left">
<div className="input-group">
<input
type="text"
value={task}
onChange={(e) => setTask(e.target.value)}
placeholder="Task description (e.g., 'pick and place')"
disabled={isRecording || isInitializing || isEncoding || isUploading}
onKeyPress={(e) => {
if (e.key === 'Enter' && robotsReady) {
setTaskOnly();
}
}}
/>
<button
onClick={setTaskOnly}
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
className="btn-set-task"
title={!robotsReady ? 'Please setup robots first' : 'Store task for pedal use (Enter key)'}
>
💾 Set Task
</button>
<button
onClick={startRecording}
disabled={isRecording || isInitializing || isEncoding || isUploading || !robotsReady}
className="btn-start"
title={!robotsReady ? 'Please setup robots first' : ''}
>
{isInitializing
? '⏳ Initializing...'
: isRecording
? '⏺ Recording...'
: robotsReady
? '⏺ Start Recording'
: '⏺ Setup Robots First'}
</button>
</div>
{/* Ramp-up Countdown */}
{isRecording && rampUpRemaining > 0 && (
<div className="ramp-up-countdown">
<div className="countdown-box">
<div className="countdown-label"> WARMING UP - PID RAMP-UP</div>
<div className="countdown-value">{rampUpRemaining.toFixed(1)}s</div>
<div className="countdown-subtitle">Recording will start automatically...</div>
</div>
</div>
)}
{/* Recording Status - Only show after ramp-up */}
{isRecording && rampUpRemaining <= 0 && (
<div className="status recording recording-active">
<div className="indicator"></div>
<div className="time-display">
<span>{formatTime(elapsedTime)}</span>
<span className="fps-display">
Loop: {loopFps.toFixed(1)} Hz
{loopFps > 0 && loopFps < 29 && <span className="fps-warning"> </span>}
</span>
<span className="fps-display">Recording: {currentFps.toFixed(1)} FPS</span>
</div>
<button onClick={stopRecording} className="btn-stop">
Stop
</button>
</div>
)}
</div>
{/* Episode Counter */}
<div className="control-right">
<div className="counter">
<div className="counter-label">Episodes Recorded</div>
<div className="counter-value">{episodeCount}</div>
<button onClick={resetCounter} className="btn-reset">
Reset
</button>
</div>
</div>
</div>
{/* Delete Latest Episode Button */}
{!isRecording && !isInitializing && latestRepoId && (
<div className="delete-episode-section">
<button
onClick={deleteLatestEpisode}
className="btn-delete"
title="Delete the latest recorded episode from HuggingFace Hub"
>
Delete Latest Episode
</button>
<div className="delete-info">Will delete: {latestRepoId}</div>
</div>
)}
{/* Move to Zero Button */}
{robotsReady && !isRecording && !isInitializing && (
<div className="zero-position-section">
<button
onClick={moveToZero}
disabled={movingToZero}
className="btn-zero-large"
title="Move both leader and follower robots to zero position (2s)"
>
{movingToZero ? '⏳ Moving to Zero Position...' : '🎯 Move to Zero Position (Leader + Follower)'}
</button>
</div>
)}
{/* Error Display */}
{error && (
<div className="error-box">
{error}
</div>
)}
</section>
</div>
{/* Right Column: Camera Feeds */}
<div className="right-column">
<section className="panel cameras">
<h2>📹 Camera Views</h2>
{robotsReady || isRecording || isInitializing ? (
<div className="camera-layout">
{/* Base camera - full width */}
<div className="camera camera-base">
<h3>Base Camera</h3>
<img src={`${API_BASE}/camera/stream/base`} alt="Base Camera" />
</div>
{/* Wrist cameras - side by side */}
<div className="camera-wrist-container">
<div className="camera camera-wrist">
<h3>Left Wrist</h3>
<img src={`${API_BASE}/camera/stream/left_wrist`} alt="Left Wrist Camera" />
</div>
<div className="camera camera-wrist">
<h3>Right Wrist</h3>
<img src={`${API_BASE}/camera/stream/right_wrist`} alt="Right Wrist Camera" />
</div>
</div>
</div>
) : (
<div className="camera-placeholder">
<p>📷 Camera feeds will appear when robots are set up</p>
<p className="hint">Click "Setup Robots" above to preview camera feeds</p>
</div>
)}
</section>
</div>
</div>
</main>
);
}
export default App;
-41
View File
@@ -1,41 +0,0 @@
# OpenArms Web Recording Interface
A web interface for recording OpenArms datasets.
## Installation
```bash
cd examples/openarms_web_interface
npm install
```
## Usage
**Start everything with one command:**
```bash
./launch.sh
```
This will:
- Start the FastAPI backend on port 8000
- Start the React frontend on port 5173
- Show live logs from both services
Then open your browser to: **http://localhost:5173**
**Stop with:** `Ctrl+C`
---
## Workflow
1. **Configure CAN interfaces** and **camera paths** in the dropdowns
2. Click **"Setup Robots"** to initialize (once at start)
3. Enter a **task description**
4. Click **"Start Recording"** to begin an episode
5. Click **"Stop Recording"** when done
6. Dataset is automatically encoded and uploaded to HuggingFace Hub as **private**
7. Repeat steps 3-6 for more episodes (no need to re-setup robots!)
---
@@ -1,12 +0,0 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>OpenArms Recording Interface</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/main.jsx"></script>
</body>
</html>
-142
View File
@@ -1,142 +0,0 @@
#!/bin/bash
# OpenArms Web Interface Launcher
# Starts Rerun viewer, FastAPI backend, and React frontend
set -e
# Colors for output
GREEN='\033[0;32m'
BLUE='\033[0;34m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color
# Get script directory
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$SCRIPT_DIR"
echo -e "${BLUE}╔════════════════════════════════════════╗${NC}"
echo -e "${BLUE}║ OpenArms Web Recording Interface ║${NC}"
echo -e "${BLUE}╚════════════════════════════════════════╝${NC}"
echo ""
# Function to cleanup on exit
cleanup() {
echo ""
echo -e "${YELLOW}Shutting down services...${NC}"
# Kill all child processes
pkill -P $$ 2>/dev/null || true
# Kill specific services by port
lsof -ti:8000 | xargs kill -9 2>/dev/null || true # Backend
lsof -ti:5173 | xargs kill -9 2>/dev/null || true # Frontend
lsof -ti:9876 | xargs kill -9 2>/dev/null || true # Rerun (if spawned)
echo -e "${GREEN}✓ Services stopped${NC}"
exit 0
}
# Register cleanup on script exit
trap cleanup EXIT INT TERM
# Check if required commands exist
command -v rerun >/dev/null 2>&1 || {
echo -e "${RED}✗ Error: 'rerun' not found. Please install: pip install rerun-sdk${NC}"
exit 1
}
command -v python >/dev/null 2>&1 || {
echo -e "${RED}✗ Error: 'python' not found${NC}"
exit 1
}
command -v npm >/dev/null 2>&1 || {
echo -e "${RED}✗ Error: 'npm' not found${NC}"
exit 1
}
# Check if node_modules exists
if [ ! -d "node_modules" ]; then
echo -e "${YELLOW}⚠ node_modules not found. Running npm install...${NC}"
npm install
echo -e "${GREEN}✓ Dependencies installed${NC}"
echo ""
fi
echo -e "${GREEN}Starting services...${NC}"
echo ""
# 1. Start FastAPI backend (Rerun will start when recording begins)
echo -e "${BLUE}[1/2]${NC} Starting FastAPI backend on port 8000..."
cd "$SCRIPT_DIR"
# Use Python from current environment (if lerobot env is active, it will use that)
# Otherwise, check if we need to use conda run
if [[ "$CONDA_DEFAULT_ENV" == "lerobot" ]]; then
# Already in lerobot environment
echo -e "${GREEN}✓ Using active lerobot environment${NC}"
PYTHON_CMD="python"
elif command -v conda >/dev/null 2>&1 && conda env list | grep -q "^lerobot "; then
# lerobot env exists but not active - use conda run
echo -e "${YELLOW}Using conda run with lerobot environment...${NC}"
PYTHON_CMD="conda run -n lerobot --no-capture-output python"
else
# Fall back to system python
echo -e "${YELLOW}⚠ Warning: lerobot environment not found, using system python${NC}"
PYTHON_CMD="python"
fi
$PYTHON_CMD web_record_server.py > /tmp/openarms_backend.log 2>&1 &
BACKEND_PID=$!
sleep 3
if ps -p $BACKEND_PID > /dev/null; then
echo -e "${GREEN}✓ Backend started${NC} (PID: $BACKEND_PID)"
echo -e " URL: ${BLUE}http://localhost:8000${NC}"
else
echo -e "${RED}✗ Failed to start backend${NC}"
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_backend.log${NC}"
exit 1
fi
echo ""
# 2. Start React frontend
echo -e "${BLUE}[2/2]${NC} Starting React frontend on port 5173..."
cd "$SCRIPT_DIR"
npm run dev > /tmp/openarms_frontend.log 2>&1 &
FRONTEND_PID=$!
sleep 3
if ps -p $FRONTEND_PID > /dev/null; then
echo -e "${GREEN}✓ Frontend started${NC} (PID: $FRONTEND_PID)"
echo -e " URL: ${BLUE}http://localhost:5173${NC}"
else
echo -e "${RED}✗ Failed to start frontend${NC}"
echo -e "${YELLOW}Check logs: tail -f /tmp/openarms_frontend.log${NC}"
exit 1
fi
echo ""
# Display status
echo -e "${GREEN}╔════════════════════════════════════════╗${NC}"
echo -e "${GREEN}║ All services running! 🚀 ║${NC}"
echo -e "${GREEN}╚════════════════════════════════════════╝${NC}"
echo ""
echo -e "🔧 ${BLUE}Backend:${NC} http://localhost:8000"
echo -e "🌐 ${BLUE}Frontend:${NC} http://localhost:5173"
echo -e "📊 ${BLUE}Rerun:${NC} Will spawn automatically when recording starts"
echo ""
echo -e "${YELLOW}Open your browser to:${NC} ${BLUE}http://localhost:5173${NC}"
echo ""
echo -e "${YELLOW}Logs:${NC}"
echo -e " • Backend: tail -f /tmp/openarms_backend.log"
echo -e " • Frontend: tail -f /tmp/openarms_frontend.log"
echo ""
echo -e "${RED}Press Ctrl+C to stop all services${NC}"
echo ""
# Keep script running and wait for any service to exit
wait
-7
View File
@@ -1,7 +0,0 @@
import { createRoot } from 'react-dom/client'
import App from './App.jsx'
createRoot(document.getElementById('root')).render(
<App />
)
File diff suppressed because it is too large Load Diff
@@ -1,21 +0,0 @@
{
"name": "openarms-web-interface",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"preview": "vite preview"
},
"dependencies": {
"react": "^18.3.1",
"react-dom": "^18.3.1"
},
"devDependencies": {
"@types/react": "^18.3.12",
"@types/react-dom": "^18.3.1",
"@vitejs/plugin-react": "^4.3.4",
"vite": "^6.0.1"
}
}
@@ -1,17 +0,0 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
// https://vite.dev/config/
export default defineConfig({
plugins: [react()],
server: {
port: 5173,
strictPort: false,
host: true,
open: false
},
build: {
outDir: 'dist',
sourcemap: true
}
})
File diff suppressed because it is too large Load Diff
@@ -15,16 +15,12 @@
# limitations under the License.
import argparse
import logging
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
from port_droid import DROID_SHARDS
class AggregateDatasets(PipelineStep):
@@ -38,6 +34,11 @@ class AggregateDatasets(PipelineStep):
self.aggr_repo_id = aggregated_repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
init_logging()
# Since aggregate_datasets already handles parallel processing internally,
+2 -2
View File
@@ -20,7 +20,7 @@ from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from port_droid import DROID_SHARDS
class PortDroidShards(PipelineStep):
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
from port_droid import port_droid, validate_dataset
from lerobot.utils.utils import init_logging
+9 -3
View File
@@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from port_droid import DROID_SHARDS
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
@@ -185,11 +185,11 @@ class UploadDataset(PipelineStep):
def make_upload_executor(
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True
):
kwargs = {
"pipeline": [
UploadDataset(repo_id),
UploadDataset(repo_id, private=private),
],
"logging_dir": str(logs_dir / job_name),
}
@@ -267,6 +267,12 @@ def main():
default="1950M",
help="Memory per cpu that each worker will use.",
)
parser.add_argument(
"--private",
action="store_true",
default=False,
help="Whether to create a private repository.",
)
init_logging()
+951
View File
@@ -0,0 +1,951 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluate Real-Time Chunking (RTC) performance on dataset samples.
This script takes two random samples from a dataset:
- Uses actions from the first sample as previous chunk
- Generates new actions for the second sample with and without RTC
It compares action predictions with and without RTC on dataset samples,
measuring consistency and ground truth alignment.
Usage:
# Basic usage with smolvla policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
--rtc.prefix_attention_schedule=EXP \
--seed=10
# Basic usage with pi0.5 policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--device=mps
--seed=10
# Basic usage with pi0.5 policy with cuda device
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
# Basic usage with pi0 policy with cuda device
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
uv run python examples/rtc/eval_dataset.py \
--policy.path=lipsop/reuben_pi0 \
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
--rtc.execution_horizon=8 \
--device=cuda
# With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--use_torch_compile=true \
--torch_compile_mode=max-autotune
# With torch.compile on CUDA (CUDA graphs disabled by default)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
--torch_compile_mode=reduce-overhead
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune \
--torch_compile_disable_cudagraphs=false
"""
import gc
import logging
import os
import random
from dataclasses import dataclass, field
import numpy as np
import torch
try:
import matplotlib.pyplot as plt
MATPLOTLIB_AVAILABLE = True
except ImportError:
MATPLOTLIB_AVAILABLE = False
plt = None
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.factory import resolve_delta_timestamps
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
def set_seed(seed: int):
"""Set random seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def _check_matplotlib_available():
"""Check if matplotlib is available, raise helpful error if not."""
if not MATPLOTLIB_AVAILABLE:
raise ImportError(
"matplotlib is required for RTC debug visualizations. "
"Please install it by running:\n"
" uv pip install matplotlib"
)
@dataclass
class RTCEvalConfig(HubMixin):
"""Configuration for RTC evaluation."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Dataset configuration
dataset: DatasetConfig = field(default_factory=DatasetConfig)
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
enabled=True,
execution_horizon=20,
max_guidance_weight=10.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=True,
debug_maxlen=1000,
)
)
# Device configuration
device: str | None = field(
default=None,
metadata={"help": "Device to run on (cuda, cpu, mps, auto)"},
)
# Output configuration
output_dir: str = field(
default="rtc_debug_output",
metadata={"help": "Directory to save debug visualizations"},
)
# Seed configuration
seed: int = field(
default=42,
metadata={"help": "Random seed for reproducibility"},
)
inference_delay: int = field(
default=4,
metadata={"help": "Inference delay for RTC"},
)
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# Parse policy path
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required (--policy.path)")
# Auto-detect device if not specified
if self.device is None or self.device == "auto":
if torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
logging.info(f"Auto-detected device: {self.device}")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
class RTCEvaluator:
"""Evaluator for RTC on dataset samples."""
def __init__(self, cfg: RTCEvalConfig):
self.cfg = cfg
self.device = cfg.device
# Load dataset with proper delta_timestamps based on policy configuration
# Calculate delta_timestamps using the same logic as make_dataset factory
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
# Get dataset metadata to extract FPS
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id)
# Calculate delta_timestamps from policy's delta_indices
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
# Create dataset with calculated delta_timestamps
self.dataset = LeRobotDataset(
cfg.dataset.repo_id,
delta_timestamps=delta_timestamps,
)
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
# Create preprocessor/postprocessor
self.preprocessor, self.postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
preprocessor_overrides={
"device_processor": {"device": self.device},
},
)
logging.info("=" * 80)
logging.info("Ready to run evaluation with sequential policy loading:")
logging.info(" 1. policy_prev_chunk - Generate reference chunk, then destroy")
logging.info(" 2. policy_no_rtc - Generate without RTC, then destroy")
logging.info(" 3. policy_rtc - Generate with RTC, then destroy")
logging.info(" Note: Only one policy in memory at a time for efficient memory usage")
logging.info("=" * 80)
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
"""Initialize a single policy instance with specified RTC configuration.
Args:
name: Name identifier for logging purposes
rtc_enabled: Whether to enable RTC for this policy
rtc_debug: Whether to enable debug tracking for this policy
Returns:
Configured policy instance with optional torch.compile applied
"""
logging.info(f"Initializing {name}...")
# Load policy from pretrained
policy_class = get_policy_class(self.cfg.policy.type)
config = PreTrainedConfig.from_pretrained(self.cfg.policy.pretrained_path)
if self.cfg.policy.type == "pi05" or self.cfg.policy.type == "pi0":
config.compile_model = self.cfg.use_torch_compile
policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=config)
policy = policy.to(self.device)
policy.eval()
# Configure RTC
rtc_config = RTCConfig(
enabled=rtc_enabled,
execution_horizon=self.cfg.rtc.execution_horizon,
max_guidance_weight=self.cfg.rtc.max_guidance_weight,
prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
debug=rtc_debug,
debug_maxlen=self.cfg.rtc.debug_maxlen,
)
policy.config.rtc_config = rtc_config
policy.init_rtc_processor()
logging.info(f" RTC enabled: {rtc_enabled}")
logging.info(f" RTC debug: {rtc_debug}")
logging.info(f" Policy config: {config}")
# Apply torch.compile to predict_action_chunk method if enabled
if self.cfg.use_torch_compile:
policy = self._apply_torch_compile(policy, name)
logging.info(f"{name} initialized successfully")
return policy
def _apply_torch_compile(self, policy, policy_name: str):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
policy_name: Name for logging purposes
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logging.warning(
f" [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
logging.info(f" Disable CUDA graphs: {self.cfg.torch_compile_disable_cudagraphs}")
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
# Compile the predict_action_chunk method
# - Debug tracker is excluded from compilation via @torch._dynamo.disable
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": self.cfg.torch_compile_backend,
"mode": self.cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if self.cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
except Exception as e:
logging.error(f" [{policy_name}] Failed to apply torch.compile: {e}")
logging.warning(f" [{policy_name}] Continuing without torch.compile")
return policy
def _destroy_policy(self, policy, policy_name: str):
"""Explicitly destroy a policy and free all associated memory.
This method performs aggressive cleanup to ensure maximum memory is freed,
which is critical for large models (e.g., VLAs with billions of parameters).
Args:
policy: Policy instance to destroy
policy_name: Name for logging purposes
"""
logging.info(f" Destroying {policy_name} and freeing memory...")
try:
# Step 1: Move policy to CPU to free GPU/MPS memory
policy.cpu()
# Step 2: Delete the policy object
del policy
# Step 3: Force garbage collection to reclaim memory immediately
gc.collect()
# Step 4: Clear device-specific caches
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Ensure all operations complete
if torch.backends.mps.is_available():
torch.mps.empty_cache()
logging.info(f"{policy_name} destroyed and memory freed")
except Exception as e:
logging.warning(f" Warning: Error during {policy_name} cleanup: {e}")
def run_evaluation(self):
"""Run evaluation on two random dataset samples using three separate policies.
Note: Policies are deinitalized after each step to free memory. Large models
(e.g., VLA models with billions of parameters) cannot fit three instances in
memory simultaneously. By deleting and garbage collecting after each step,
we ensure only one policy is loaded at a time.
"""
# Create output directory
os.makedirs(self.cfg.output_dir, exist_ok=True)
logging.info(f"Output directory: {self.cfg.output_dir}")
logging.info("=" * 80)
logging.info("Starting RTC evaluation")
logging.info(f"Inference delay: {self.cfg.inference_delay}")
logging.info("=" * 80)
# Load two random samples from dataset
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
loader_iter = iter(data_loader)
first_sample = next(loader_iter)
second_sample = next(loader_iter)
preprocessed_first_sample = self.preprocessor(first_sample)
preprocessed_second_sample = self.preprocessor(second_sample)
# ============================================================================
# Step 1: Generate previous chunk using policy_prev_chunk
# ============================================================================
# This policy is only used to generate the reference chunk and then freed
logging.info("=" * 80)
logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
logging.info("=" * 80)
# Initialize policy 1
policy_prev_chunk_policy = self._init_policy(
name="policy_prev_chunk",
rtc_enabled=False,
rtc_debug=False,
)
with torch.no_grad():
prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
preprocessed_first_sample,
)[:, :25, :].squeeze(0)
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
# Destroy policy_prev_chunk to free memory for large models
self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")
# ============================================================================
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc
# ============================================================================
logging.info("=" * 80)
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
logging.info("=" * 80)
set_seed(self.cfg.seed)
# Initialize policy 2
policy_no_rtc_policy = self._init_policy(
name="policy_no_rtc",
rtc_enabled=False,
rtc_debug=True,
)
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
noise_clone = noise.clone()
policy_no_rtc_policy.rtc_processor.reset_tracker()
with torch.no_grad():
no_rtc_actions = policy_no_rtc_policy.predict_action_chunk(
preprocessed_second_sample,
noise=noise,
)
no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
logging.info(f" Generated no_rtc_actions shape: {no_rtc_actions.shape}")
# Destroy policy_no_rtc to free memory before loading policy_rtc
self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
# ============================================================================
# Step 3: Generate actions WITH RTC using policy_rtc
# ============================================================================
logging.info("=" * 80)
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
logging.info("=" * 80)
set_seed(self.cfg.seed)
# Initialize policy 3
policy_rtc_policy = self._init_policy(
name="policy_rtc",
rtc_enabled=True,
rtc_debug=True,
)
policy_rtc_policy.rtc_processor.reset_tracker()
with torch.no_grad():
rtc_actions = policy_rtc_policy.predict_action_chunk(
preprocessed_second_sample,
noise=noise_clone,
inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=self.cfg.rtc.execution_horizon,
)
rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
logging.info(f" Generated rtc_actions shape: {rtc_actions.shape}")
# Save num_steps before destroying policy (needed for plotting)
try:
num_steps = policy_rtc_policy.config.num_steps
except Exception as e:
logging.error(f" Error getting num_steps: {e}")
num_steps = policy_rtc_policy.config.num_inference_steps
logging.warning(f" Using num_inference_steps: {num_steps} instead of num_steps")
# Destroy policy_rtc after final use
self._destroy_policy(policy_rtc_policy, "policy_rtc")
# Plot and save results
logging.info("=" * 80)
logging.info("Plotting results...")
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
# Plot final actions comparison
logging.info("=" * 80)
logging.info("Plotting final actions comparison...")
self.plot_final_actions_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over)
logging.info("=" * 80)
logging.info("Evaluation completed successfully")
def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
"""Plot final action predictions comparison on a single chart.
Args:
rtc_actions: Final actions from RTC policy
no_rtc_actions: Final actions from non-RTC policy
prev_chunk_left_over: Previous chunk used as ground truth
"""
_check_matplotlib_available()
# Remove batch dimension if present
rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
no_rtc_actions_plot = (
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
)
prev_chunk_plot = prev_chunk_left_over.cpu()
# Create figure with 6 subplots (one per action dimension)
fig, axes = plt.subplots(6, 1, figsize=(16, 12))
fig.suptitle("Final Action Predictions Comparison (Raw)", fontsize=16)
# Plot each action dimension
for dim_idx, ax in enumerate(axes):
# Plot previous chunk (ground truth) in red
RTCDebugVisualizer.plot_waypoints(
[ax],
prev_chunk_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="red",
label="Previous Chunk (Ground Truth)",
linewidth=2.5,
alpha=0.8,
)
# Plot no-RTC actions in blue
RTCDebugVisualizer.plot_waypoints(
[ax],
no_rtc_actions_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="blue",
label="No RTC",
linewidth=2,
alpha=0.7,
)
# Plot RTC actions in green
RTCDebugVisualizer.plot_waypoints(
[ax],
rtc_actions_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="green",
label="RTC",
linewidth=2,
alpha=0.7,
)
# Add vertical lines for inference delay and execution horizon
inference_delay = self.cfg.inference_delay
execution_horizon = self.cfg.rtc.execution_horizon
if inference_delay > 0:
ax.axvline(
x=inference_delay - 1,
color="orange",
linestyle="--",
alpha=0.5,
label=f"Inference Delay ({inference_delay})",
)
if execution_horizon > 0:
ax.axvline(
x=execution_horizon,
color="purple",
linestyle="--",
alpha=0.5,
label=f"Execution Horizon ({execution_horizon})",
)
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3)
# Set x-axis ticks to show all integer values
max_len = max(rtc_actions_plot.shape[0], no_rtc_actions_plot.shape[0], prev_chunk_plot.shape[0])
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5)
axes[-1].set_xlabel("Step", fontsize=10)
# Collect legend handles and labels from first subplot
handles, labels = axes[0].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right)
fig.legend(
unique_handles,
unique_labels,
loc="center right",
fontsize=9,
bbox_to_anchor=(1.0, 0.5),
framealpha=0.9,
)
# Save figure
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right
fig.savefig(output_path, dpi=150, bbox_inches="tight")
logging.info(f"Saved final actions comparison to {output_path}")
plt.close(fig)
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
_check_matplotlib_available()
# Create side-by-side figures for denoising visualization
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
fig_corr, axs_corr = self._create_figure("Correction: No RTC (left) vs RTC (right)")
fig_x1t, axs_x1t = self._create_figure(
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
)
self._plot_denoising_steps_from_tracker(
rtc_tracked_steps,
axs_xt[:, 1], # Right column for x_t
axs_vt[:, 1], # Right column for v_t
axs_corr[:, 1], # Right column for correction
axs_x1t[:, 1], # Right column for x1_t
num_steps,
add_labels=True, # Add labels for RTC (right column)
)
self._plot_denoising_steps_from_tracker(
no_rtc_tracked_steps,
axs_xt[:, 0], # Left column for x_t
axs_vt[:, 0], # Left column for v_t
axs_corr[:, 0], # Left column for correction
axs_x1t[:, 0], # Left column for x1_t
num_steps,
add_labels=False, # No labels for No RTC (left column)
)
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
self._plot_no_rtc_xt_reference(no_rtc_tracked_steps, axs_xt[:, 1], num_steps)
# Plot ground truth on x_t axes
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x1_t axes
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x_t axes (no labels for left column)
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
)
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
)
# Add legends outside the plot area for each figure
self._add_figure_legend(fig_xt, axs_xt)
self._add_figure_legend(fig_vt, axs_vt)
self._add_figure_legend(fig_corr, axs_corr)
self._add_figure_legend(fig_x1t, axs_x1t)
# Save denoising plots
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
self._save_figure(fig_corr, os.path.join(self.cfg.output_dir, "denoising_correction_comparison.png"))
self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png"))
def _create_figure(self, title):
fig, axs = plt.subplots(6, 2, figsize=(24, 12))
fig.suptitle(title, fontsize=16)
for ax in axs[:, 0]:
ax.set_title("No RTC (N/A)" if ax == axs[0, 0] else "", fontsize=12)
for ax in axs[:, 1]:
ax.set_title("RTC" if ax == axs[0, 1] else "", fontsize=12)
return fig, axs
def _add_figure_legend(self, fig, axs):
"""Add a legend outside the plot area on the right side.
Args:
fig: Matplotlib figure to add legend to
axs: Array of axes to collect legend handles from
"""
# Collect all handles and labels from the first row of axes (right column)
handles, labels = axs[0, 1].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right, close to charts)
if unique_handles:
fig.legend(
unique_handles,
unique_labels,
loc="center left",
fontsize=8,
bbox_to_anchor=(0.87, 0.5),
framealpha=0.9,
ncol=1,
)
def _save_figure(self, fig, path):
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend/colorbar on right
fig.savefig(path, dpi=150, bbox_inches="tight")
logging.info(f"Saved figure to {path}")
plt.close(fig)
def _plot_denoising_steps_from_tracker(
self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps, add_labels=True
):
"""Plot denoising steps from tracker data.
Args:
tracked_steps: List of DebugStep objects containing debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
corr_axs: Matplotlib axes for correction plots (array of 6 axes)
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
num_steps: Total number of denoising steps for colormap
add_labels: Whether to add legend labels for the plots
"""
logging.info("=" * 80)
logging.info(f"Plotting {len(tracked_steps)} steps")
debug_steps = tracked_steps
if not debug_steps:
return
# Define colors for different denoise steps (using a colormap)
colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
for step_idx, debug_step in enumerate(debug_steps):
color = colors[step_idx % len(colors)]
label = f"Step {step_idx}" if add_labels else None
# Plot x_t
if debug_step.x_t is not None:
RTCDebugVisualizer.plot_waypoints(
xt_axs, debug_step.x_t, start_from=0, color=color, label=label
)
# Plot v_t
if debug_step.v_t is not None:
RTCDebugVisualizer.plot_waypoints(
vt_axs, debug_step.v_t, start_from=0, color=color, label=label
)
# Plot correction on separate axes
if debug_step.correction is not None:
RTCDebugVisualizer.plot_waypoints(
corr_axs,
debug_step.correction,
start_from=0,
color=color,
label=label,
)
# Plot x1_t (predicted state)
if x1t_axs is not None and debug_step.x1_t is not None:
x1t_label = f"x1_t Step {step_idx}" if add_labels else None
RTCDebugVisualizer.plot_waypoints(
x1t_axs,
debug_step.x1_t,
start_from=0,
color=color,
label=x1t_label,
)
# Plot error in orange dashed
if x1t_axs is not None and debug_step.err is not None:
error_chunk = (
debug_step.err[0].cpu().numpy()
if len(debug_step.err.shape) == 3
else debug_step.err.cpu().numpy()
)
num_dims = min(error_chunk.shape[-1], 6)
error_label = f"error Step {step_idx}" if add_labels else None
for j in range(num_dims):
x1t_axs[j].plot(
np.arange(0, error_chunk.shape[0]),
error_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
label=error_label,
)
# Recalculate axis limits after plotting to ensure proper scaling
self._rescale_axes(xt_axs)
self._rescale_axes(vt_axs)
self._rescale_axes(corr_axs)
self._rescale_axes(x1t_axs)
def _plot_no_rtc_xt_reference(self, no_rtc_tracked_steps, xt_axs, num_steps):
"""Plot final no-RTC x_t data as orange dashed line on the RTC chart for comparison.
Args:
no_rtc_tracked_steps: List of DebugStep objects containing no-RTC debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes, right column)
num_steps: Total number of denoising steps for colormap
"""
debug_steps = no_rtc_tracked_steps
if not debug_steps:
return
# Plot only the final x_t step as orange dashed line
final_step = debug_steps[-1]
logging.info("Plotting final no-RTC x_t step as orange dashed reference")
if final_step.x_t is not None:
x_t_chunk = (
final_step.x_t[0].cpu().numpy()
if len(final_step.x_t.shape) == 3
else final_step.x_t.cpu().numpy()
)
num_dims = min(x_t_chunk.shape[-1], 6)
for j in range(num_dims):
xt_axs[j].plot(
np.arange(0, x_t_chunk.shape[0]),
x_t_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
linewidth=2,
label="No RTC (final)" if j == 0 else "",
)
def _rescale_axes(self, axes):
"""Rescale axes to show all data with proper margins.
Args:
axes: Array of matplotlib axes to rescale
"""
for ax in axes:
ax.relim()
ax.autoscale_view()
# Add 10% margin to y-axis for better visualization
ylim = ax.get_ylim()
y_range = ylim[1] - ylim[0]
if y_range > 0: # Avoid division by zero
margin = y_range * 0.1
ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
# Set x-axis ticks to show all integer values
xlim = ax.get_xlim()
max_len = int(xlim[1]) + 1
if max_len > 0:
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5)
@parser.wrap()
def main(cfg: RTCEvalConfig):
"""Main entry point for RTC evaluation."""
# Set random seed for reproducibility
set_seed(cfg.seed)
init_logging()
logging.info("=" * 80)
logging.info("RTC Dataset Evaluation")
logging.info(f"Config: {cfg}")
logging.info("=" * 80)
evaluator = RTCEvaluator(cfg)
evaluator.run_evaluation()
if __name__ == "__main__":
main()
+549
View File
@@ -0,0 +1,549 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
This script demonstrates:
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
2. Consuming actions from the policy while the robot executes
3. Periodically requesting new action chunks in the background using threads
4. Managing action buffers and timing for real-time operation
For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
"""
import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
so100_follower,
so101_follower,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
# It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
# The stats are embedded in the processor .safetensors files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
obs = robot.get_observation()
# Apply robot observation processor
obs_processed = robot_observation_processor(obs)
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
preproceseded_obs = preprocessor(obs_with_policy_features)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Store original actions (before postprocessing) for RTC
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot.
Args:
robot: The robot instance
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_count = 0
action_interval = 1.0 / cfg.fps
while not shutdown_event.is_set():
start_time = time.perf_counter()
# Try to get an action from the queue with timeout
action = action_queue.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
time.sleep(max(0, (action_interval - dt_s) - 0.001))
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with draccus configuration."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init RTC processort, as by default if RTC disabled in the config
# The processor won't be created
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
logger.info("Started stop by duration thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
time.sleep(10)
# Log queue status periodically
if int(time.time() - start_time) % 5 == 0:
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
if time.time() - start_time > cfg.duration:
break
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")
-10
View File
@@ -1,10 +0,0 @@
from huggingface_hub import HfApi, list_datasets
api = HfApi()
datasets = list_datasets(author="lerobot-data-collection")
print('"[', end="")
i=0
for dataset in datasets:
if "three-folds-dataset" in dataset.id:
print("'" + dataset.id + "',", end="")
print(']"',)
+2 -5
View File
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.1"
version = "0.4.2"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
@@ -102,10 +102,8 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (com
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
# Robots
openarms = ["lerobot[damiao]"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
@@ -144,13 +142,12 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
metaworld = ["metaworld==3.0.0"]
# All
all = [
"lerobot[dynamixel]",
"lerobot[openarms]",
"lerobot[gamepad]",
"lerobot[hopejr]",
"lerobot[lekiwi]",
+761
View File
@@ -0,0 +1,761 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inference script for SARM (Stage-Aware Reward Model).
This script loads a trained SARM model and runs inference on a dataset episode,
generating visualizations of the predicted task stages and progress over time.
Example usage:
python scripts/visualize_sarm_predictions.py \
--model-id username/sarm-model \
--dataset-repo lerobot/aloha_sim_insertion_human \
--episode-index 0 \
--output-dir outputs/sarm_viz \
--task-description "insert the peg into the socket"
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
from lerobot.policies.sarm.sarm_utils import (
pad_state_to_max_dim,
compute_tau,
compute_cumulative_progress_batch,
)
from lerobot.datasets.utils import load_stats
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Run SARM inference and visualize predictions")
# Model arguments
parser.add_argument(
"--model-id",
type=str,
required=True,
help="HuggingFace model ID or local path to trained SARM model"
)
# Dataset arguments
parser.add_argument(
"--dataset-repo",
type=str,
required=True,
help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)"
)
parser.add_argument(
"--episode-index",
type=int,
default=0,
help="Index of the episode to visualize (default: 0)"
)
parser.add_argument(
"--task-description",
type=str,
default="perform the task",
help="Task description for the reward model (default: 'perform the task')"
)
# Output arguments
parser.add_argument(
"--output-dir",
type=str,
default="outputs/sarm_inference",
help="Directory to save visualization outputs (default: outputs/sarm_inference)"
)
parser.add_argument(
"--image-key",
type=str,
default=None,
help="Key for images in dataset (e.g., observation.images.image). If not specified, uses model config's image_key"
)
parser.add_argument(
"--state-key",
type=str,
default=None,
help="Key for joint states in dataset. If None, auto-detects from dataset"
)
# Visualization options
parser.add_argument(
"--show-frames",
action="store_true",
help="Include sample frames in the visualization"
)
parser.add_argument(
"--num-sample-frames",
type=int,
default=8,
help="Number of sample frames to show (default: 8)"
)
parser.add_argument(
"--figsize",
type=int,
nargs=2,
default=[14, 8],
help="Figure size as width height (default: 14 8)"
)
# Device
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to run inference on (cuda/cpu, default: auto-detect)"
)
return parser.parse_args()
def load_episode_data(
dataset: LeRobotDataset,
episode_index: int,
image_key: str,
state_key: str | None = None
) -> tuple[np.ndarray, np.ndarray, int, int, str]:
"""
Load all frames and states from a specific episode.
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode to load
image_key: Key for accessing images in the dataset
state_key: Key for accessing joint states (auto-detected if None)
Returns:
Tuple of (frames, states, start_index, end_index, task_description)
"""
# Get episode boundaries
episode_data = dataset.meta.episodes
start_idx = episode_data["dataset_from_index"][episode_index]
end_idx = episode_data["dataset_to_index"][episode_index]
logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")
# Auto-detect state key if not provided
if state_key is None:
first_item = dataset[start_idx]
state_keys = [k for k in first_item.keys() if 'state' in k.lower() or 'qpos' in k.lower()]
if state_keys:
state_key = state_keys[0]
logger.info(f"Auto-detected state key: {state_key}")
# Get task description from the dataset if available
task_description = None
first_item = dataset[start_idx]
if "task" in first_item:
task_description = first_item["task"]
logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")
# Load all frames and states from the episode
frames = []
states = []
for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
item = dataset[idx]
# Get image
img = item[image_key]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
# Handle different image formats (C, H, W) or (H, W, C)
if img.shape[0] in [1, 3]: # Channel first
img = np.transpose(img, (1, 2, 0))
# Convert to uint8 if needed
if img.dtype != np.uint8:
if img.max() <= 1.0:
img = (img * 255).astype(np.uint8)
else:
img = img.astype(np.uint8)
frames.append(img)
# Get state if available
if state_key and state_key in item:
state = item[state_key]
if isinstance(state, torch.Tensor):
state = state.cpu().numpy()
states.append(state)
frames = np.array(frames)
states = np.array(states) if states else None
logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
if states is not None:
logger.info(f"Loaded states with shape {states.shape}")
return frames, states, start_idx, end_idx, task_description
@torch.no_grad()
def run_inference(
model: SARMRewardModel,
frames: np.ndarray,
states: Optional[np.ndarray],
task_description: str,
dataset_stats: dict | None = None,
state_key: str = "observation.state",
batch_size: int = 32
) -> tuple[np.ndarray, np.ndarray]:
"""
Run SARM inference on video frames and joint states.
(per SARM paper Section A.4):
- Frame 0: Initial frame of the episode (frame 0)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame t
Pattern: [frame_0, t-(7*gap), t-(6*gap), ..., t-gap, t]
Args:
model: SARM model
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
states: Joint states (num_frames, state_dim)
task_description: Task description text
dataset_stats: Dataset statistics for state normalization (same as training)
state_key: Key for state in dataset_stats
batch_size: Batch size for processing slices
Returns:
Tuple of (progress_predictions, stage_predictions)
- progress_predictions: (num_frames,)
- stage_predictions: (num_frames, num_stages)
"""
logger.info("Encoding video frames with CLIP...")
video_embeddings = model.encode_images(frames)
logger.info("Encoding task description with CLIP...")
text_embedding = model.encode_text(task_description)
# Get config values
num_frames_model = model.config.num_frames # 9
frame_gap = model.config.frame_gap # 30
logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
# Convert to tensors
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
if states is not None:
state_embeddings = torch.tensor(states, dtype=torch.float32)
# Normalize states using dataset stats (same as training processor)
if dataset_stats is not None and state_key in dataset_stats:
mean = torch.tensor(dataset_stats[state_key]["mean"], dtype=torch.float32)
std = torch.tensor(dataset_stats[state_key]["std"], dtype=torch.float32)
state_embeddings = (state_embeddings - mean) / (std + 1e-8)
logger.info(f"✓ Applied MEAN_STD normalization to states using {state_key}")
else:
logger.warning("⚠ No dataset_stats provided - states not normalized (may differ from training)")
else:
state_embeddings = None
video_slices = []
state_slices = []
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices using symmetric bidirectional pattern:
# [initial (0), t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
# Boundary handling: clamp to [0, last_valid]
deltas = model.config.observation_delta_indices
last_valid = len(video_embeddings) - 1
frame_indices = []
for delta in deltas:
idx = current_frame + delta
idx = max(0, min(idx, last_valid)) # Clamp to valid range
frame_indices.append(idx)
video_slice = video_embeddings[frame_indices]
video_slices.append(video_slice)
if state_embeddings is not None:
state_slice = state_embeddings[frame_indices]
state_slices.append(state_slice)
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
if state_embeddings is not None:
state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim)
# Pad states to max_state_dim (same as training processor)
state_slices = pad_state_to_max_dim(state_slices, model.config.max_state_dim)
else:
state_slices = None
logger.info("Running SARM inference on all slices...")
# Process in batches
all_progress = []
all_stages = []
for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
batch_video = video_slices[i:i + batch_size].to(model.device)
batch_states = state_slices[i:i + batch_size].to(model.device) if state_slices is not None else None
batch_size_actual = batch_video.shape[0]
# Replicate text embedding for batch
batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)
# Get predictions
stage_logits, stage_probs, progress_preds = model.sarm_transformer(
batch_video, batch_text, batch_states
)
# Extract predictions at the "current frame" position
# With symmetric pattern [initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g],
# the current frame is at position 5 (0-indexed)
current_frame_idx = 5
batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy()
batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy()
all_progress.extend(batch_progress)
all_stages.extend(batch_stages)
return np.array(all_progress), np.array(all_stages)
def compute_ground_truth_progress(
dataset: LeRobotDataset,
episode_index: int,
temporal_proportions: dict[str, float],
subtask_names_ordered: list[str],
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
"""
Compute ground truth progress and stage labels for an episode using annotations.
Uses SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode
temporal_proportions: Dict mapping subtask name to proportion
subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)
Returns:
Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None) if no annotations
"""
# Load episode metadata
episodes_df = dataset.meta.episodes.to_pandas()
# Check if annotations exist
if "subtask_names" not in episodes_df.columns:
logger.warning("No subtask_names column found in episodes metadata")
return None, None
ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
logger.warning(f"No annotations found for episode {episode_index}")
return None, None
subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]
# Get episode boundaries
ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
num_frames = ep_end - ep_start
# Get temporal proportions as ordered list
temporal_proportions_list = [
temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
]
logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
logger.info(f"Subtask names in episode: {ep_subtask_names}")
logger.info(f"Subtask start frames: {subtask_start_frames}")
logger.info(f"Subtask end frames: {subtask_end_frames}")
logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")
# Compute ground truth for each frame
gt_progress = np.zeros(num_frames)
gt_stages = np.zeros(num_frames, dtype=np.int32)
for frame_rel in range(num_frames):
# Find which subtask this frame belongs to
found = False
for j, (name, start_frame, end_frame) in enumerate(zip(ep_subtask_names, subtask_start_frames, subtask_end_frames)):
if frame_rel >= start_frame and frame_rel <= end_frame:
# Found the subtask - get its global index
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else 0
# Compute τ_t using utility function
tau = compute_tau(frame_rel, start_frame, end_frame)
# Compute cumulative progress using utility function
progress = compute_cumulative_progress_batch(tau, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
found = True
break
if not found:
# Handle frames outside annotated subtasks
if frame_rel < subtask_start_frames[0]:
gt_progress[frame_rel] = 0.0
gt_stages[frame_rel] = 0
elif frame_rel > subtask_end_frames[-1]:
gt_progress[frame_rel] = 1.0
gt_stages[frame_rel] = len(subtask_names_ordered) - 1
else:
# Between subtasks - find previous subtask
for j in range(len(ep_subtask_names) - 1):
if frame_rel > subtask_end_frames[j] and frame_rel < subtask_start_frames[j + 1]:
name = ep_subtask_names[j]
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else j
progress = compute_cumulative_progress_batch(1.0, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
break
logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
return gt_progress, gt_stages
def visualize_predictions(
frames: np.ndarray,
progress_predictions: np.ndarray,
stage_predictions: np.ndarray,
task_description: str,
output_path: Path,
num_sample_frames: int = 8,
figsize: tuple = (14, 8),
subtask_names: list[str] | None = None,
temporal_proportions: dict[str, float] | None = None,
ground_truth_progress: np.ndarray | None = None,
ground_truth_stages: np.ndarray | None = None,
):
"""
Create visualization of SARM predictions with optional ground truth comparison.
Args:
frames: Video frames (num_frames, H, W, C)
progress_predictions: Progress predictions (num_frames,)
stage_predictions: Stage probabilities (num_frames, num_stages)
task_description: Task description
output_path: Path to save the figure
num_sample_frames: Number of frames to show
figsize: Figure size (width, height)
subtask_names: Optional list of subtask names for labeling
temporal_proportions: Optional dict of temporal proportions for each subtask
ground_truth_progress: Optional ground truth progress array (num_frames,)
ground_truth_stages: Optional ground truth stage indices array (num_frames,)
"""
num_stages = stage_predictions.shape[1]
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
# Use subtask names if available, otherwise use generic labels
if subtask_names is not None and len(subtask_names) == num_stages:
stage_labels = subtask_names
else:
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
# Create figure with progress plot, stage plot, and sample frames
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
ax_progress = fig.add_subplot(gs[0])
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
ax_frames = fig.add_subplot(gs[2])
frame_indices = np.arange(len(progress_predictions))
# Plot 1: Progress over time
ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')
# Plot ground truth if available
if ground_truth_progress is not None:
ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
linestyle='--', label='Ground Truth Progress')
ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')
ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax_progress.set_ylabel('Task Progress', fontsize=12)
ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
ax_progress.grid(True, alpha=0.3)
ax_progress.set_ylim(-0.05, 1.1)
ax_progress.legend(loc='upper left')
# Add statistics box
stats_text = (
f'Frames: {len(progress_predictions)}\n'
f'Final Progress: {progress_predictions[-1]:.3f}\n'
f'Max Progress: {progress_predictions.max():.3f}\n'
f'Mean Progress: {progress_predictions.mean():.3f}'
)
if ground_truth_progress is not None:
mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
stats_text += f'\nMSE vs GT: {mse:.4f}'
stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'
ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
fontsize=10, verticalalignment='bottom', horizontalalignment='right',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
# Plot 2: Stage predictions (stacked area plot)
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
colors=stage_colors, alpha=0.8, labels=stage_labels)
# Plot ground truth stage as vertical bands or markers
if ground_truth_stages is not None:
# Find stage transition points in ground truth
stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
for change_idx in stage_changes:
ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)
# Add small markers at bottom showing GT stage
gt_stage_normalized = ground_truth_stages / max(num_stages - 1, 1)
ax_stages.scatter(frame_indices[::30], np.zeros(len(frame_indices[::30])) + 0.02,
c=[stage_colors[s] for s in ground_truth_stages[::30]],
s=20, marker='|', alpha=0.8, label='GT Stage Markers')
ax_stages.set_xlabel('Frame Index', fontsize=12)
ax_stages.set_ylabel('Stage Probability', fontsize=12)
ax_stages.set_ylim(0, 1)
ax_stages.grid(True, alpha=0.3)
# Adjust legend based on number of stages and label lengths
if num_stages <= 5:
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
else:
ax_stages.legend(loc='upper left', ncol=3, fontsize=7)
# Add vertical lines and labels for expected stage transitions (if temporal proportions available)
if temporal_proportions is not None and subtask_names is not None:
cumulative_progress = 0.0
for i, name in enumerate(stage_labels):
if name in temporal_proportions:
# Find approximate frame where this stage should end
stage_end_progress = cumulative_progress + temporal_proportions[name]
# Find frame index closest to this progress
progress_diffs = np.abs(progress_predictions - stage_end_progress)
stage_end_frame = np.argmin(progress_diffs)
# Draw vertical line
ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
cumulative_progress = stage_end_progress
# Plot 3: Sample frames (if requested)
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
ax_frames.axis('off')
# Create grid for frames
frame_height = frames[0].shape[0]
frame_width = frames[0].shape[1]
combined_width = frame_width * num_sample_frames
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
for i, frame_idx in enumerate(frame_indices_to_show):
frame = frames[frame_idx]
if frame.shape[-1] == 1:
frame = np.repeat(frame, 3, axis=-1)
# Add frame to combined image
x_start = i * frame_width
x_end = (i + 1) * frame_width
combined_image[:, x_start:x_end] = frame
# Add frame number, progress, and stage
progress_val = progress_predictions[frame_idx]
stage_idx = np.argmax(stage_predictions[frame_idx])
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
# Truncate long stage names for display
if len(stage_name) > 15:
stage_name = stage_name[:12] + '...'
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
# Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label,
ha='center', va='top', fontsize=7,
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
ax_frames.imshow(combined_image)
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
plt.tight_layout()
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
logger.info(f"Saved visualization to {output_path}")
plt.close()
def main():
args = parse_args()
# Setup device
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
else:
device = args.device
logger.info(f"Using device: {device}")
# Load model
logger.info(f"Loading SARM model from {args.model_id}...")
model = SARMRewardModel.from_pretrained(args.model_id)
model.to(device)
model.eval()
logger.info("Model loaded successfully")
# Load dataset
logger.info(f"Loading dataset {args.dataset_repo}...")
dataset = LeRobotDataset(args.dataset_repo)
logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")
# Validate episode index
if args.episode_index >= len(dataset.meta.episodes):
raise ValueError(
f"Episode index {args.episode_index} out of range. "
f"Dataset has {len(dataset.meta.episodes)} episodes."
)
image_key = args.image_key if args.image_key is not None else model.config.image_key
state_key = args.state_key if args.state_key is not None else model.config.state_key
logger.info(f"Using image key: {image_key}")
logger.info(f"Using state key: {state_key}")
# Load dataset stats for state normalization (same as training)
dataset_stats = load_stats(dataset.root)
if dataset_stats:
logger.info(f"✓ Loaded dataset stats from {dataset.root}")
else:
logger.warning("⚠ Could not load dataset stats - states will not be normalized")
# Load episode data
frames, states, start_idx, end_idx, dataset_task = load_episode_data(
dataset, args.episode_index, image_key, state_key
)
# Use task description from dataset if available, otherwise use command-line argument
task_description = dataset_task if dataset_task is not None else args.task_description
logger.info(f"Using task description: '{task_description}'")
# Run inference
progress_predictions, stage_predictions = run_inference(
model, frames, states, task_description,
dataset_stats=dataset_stats, state_key=state_key
)
# Extract subtask names and temporal proportions from model config if available
subtask_names = None
temporal_proportions = None
if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None:
subtask_names = model.config.subtask_names
logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
# Try to load temporal proportions from model config
if hasattr(model.config, 'temporal_proportions') and model.config.temporal_proportions is not None:
temporal_proportions = {
name: prop for name, prop in zip(model.config.subtask_names, model.config.temporal_proportions)
}
logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")
# Fallback: try to load from dataset meta
if temporal_proportions is None:
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
if proportions_path.exists():
with open(proportions_path, 'r') as f:
temporal_proportions = json.load(f)
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
# Also extract subtask names from proportions if not already set
if subtask_names is None:
subtask_names = sorted(temporal_proportions.keys())
logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")
# Compute ground truth progress if annotations are available
ground_truth_progress = None
ground_truth_stages = None
if temporal_proportions is not None and subtask_names is not None:
logger.info("Attempting to compute ground truth progress from annotations...")
ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
dataset,
args.episode_index,
temporal_proportions,
subtask_names
)
if ground_truth_progress is None:
logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
else:
logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")
output_dir = Path(args.output_dir)
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
visualize_predictions(
frames,
progress_predictions,
stage_predictions,
task_description,
output_path,
num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize),
subtask_names=subtask_names,
temporal_proportions=temporal_proportions,
ground_truth_progress=ground_truth_progress,
ground_truth_stages=ground_truth_stages,
)
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
save_dict = {
'progress': progress_predictions,
'stages': stage_predictions
}
if ground_truth_progress is not None:
save_dict['gt_progress'] = ground_truth_progress
save_dict['gt_stages'] = ground_truth_stages
np.savez(predictions_path, **save_dict)
logger.info(f"Saved predictions to {predictions_path}")
logger.info(f"\nVisualization: {output_path}")
if __name__ == "__main__":
main()
+19 -2
View File
@@ -64,9 +64,26 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
checkpoint_path: Path | None = field(init=False, default=None)
# RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False # Enable reward-weighted training
reward_model_path: str | None = None # Path to pre-trained reward model (e.g., SARM)
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
rabc_update_freq: int = 1 # Compute rewards every N batches (1 = every batch)
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
rename_map: dict[str, str] = field(default_factory=dict)
checkpoint_path: Path | None = field(init=False, default=None)
def validate(self):
# Validate RA-BC configuration
if self.use_rabc and not self.reward_model_path:
raise ValueError(
"RA-BC is enabled (use_rabc=True) but no reward_model_path provided. "
"Please specify a pre-trained reward model (e.g., SARM) path."
)
def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
+7
View File
@@ -43,3 +43,10 @@ class NormalizationMode(str, Enum):
class PolicyFeature:
type: FeatureType
shape: tuple[int, ...]
class RTCAttentionSchedule(str, Enum):
ZEROS = "ZEROS"
ONES = "ONES"
LINEAR = "LINEAR"
EXP = "EXP"
+25 -21
View File
@@ -39,6 +39,7 @@ from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -962,28 +963,23 @@ def _copy_data_with_feature_changes(
remove_features: list[str] | None = None,
) -> None:
"""Copy data while adding or removing features."""
if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.meta.root)
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
# Map file paths to episode indices to extract chunk/file indices
file_to_episodes: dict[Path, set[int]] = {}
for ep_idx in range(dataset.meta.total_episodes):
file_path = dataset.meta.get_data_file_path(ep_idx)
if file_path not in file_to_episodes:
file_to_episodes[file_path] = set()
file_to_episodes[file_path].add(ep_idx)
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
frame_idx = 0
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
for src_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
# Get chunk_idx and file_idx from the source file's first episode
episodes_in_file = file_to_episodes[src_path]
first_ep_idx = min(episodes_in_file)
src_ep = dataset.meta.episodes[first_ep_idx]
chunk_idx = src_ep["data/chunk_index"]
file_idx = src_ep["data/file_index"]
relative_path = src_path.relative_to(dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
if remove_features:
df = df.drop(columns=remove_features, errors="ignore")
@@ -1003,13 +999,21 @@ def _copy_data_with_feature_changes(
df[feature_name] = feature_values
else:
feature_slice = values[frame_idx:end_idx]
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
df[feature_name] = feature_slice.flatten()
else:
if len(feature_slice.shape) == 1:
# 1D array - can assign directly
df[feature_name] = feature_slice
elif len(feature_slice.shape) == 2 and feature_slice.shape[1] == 1:
# 2D array with single column - flatten it
df[feature_name] = feature_slice.flatten()
elif len(feature_slice.shape) == 2:
# 2D array with multiple columns (e.g., embeddings) - convert to list of lists
df[feature_name] = feature_slice.tolist()
else:
# Higher dimensional - convert to list
df[feature_name] = [row.tolist() for row in feature_slice]
frame_idx = end_idx
# Write using the preserved chunk_idx and file_idx from source
# Write using the same chunk/file structure as source
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
@@ -0,0 +1,146 @@
# LeRobot Embedding Generation Script
Generate embeddings for LeRobot datasets to make them more lightweight and efficient for training.
## Overview
This script processes v3.0 LeRobot datasets and adds pre-computed embeddings for:
- **Task embeddings**: Language command embeddings using MiniLM
- **Image embeddings**: Frame embeddings using DinoV2
The resulting dataset can be used more efficiently during training by loading pre-computed embeddings instead of running encoders on-the-fly.
## Supported Encoders
### Image Encoders (DinoV2)
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings:
- **`dinov2_vits14`**: ViT-S/14 (384-dim) - Fastest, smaller model
- **`dinov2_vitb14`**: ViT-B/14 (768-dim) - **Recommended** - Good balance
- **`dinov2_vitl14`**: ViT-L/14 (1024-dim) - Best quality, slower
### Language Encoders (MiniLM)
MiniLM is a lightweight sentence transformer model:
- **`minilm-l6`**: MiniLM-L6-v2 (384-dim) - Faster
- **`minilm-l12`**: MiniLM-L12-v2 (384-dim) - **Recommended** - Better quality
## Usage
### Basic Command
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--push-to-hub
```
### Lightweight Version (No Videos)
Removes video files to significantly reduce storage:
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
```
## Output
The script adds new features to your dataset:
### New Features
1. **`task_embedding`**: Language embedding for each frame
- Shape: `[384]` (MiniLM)
- One embedding per frame based on its task
2. **`{camera_key}_embedding`**: Image embedding for each camera view
- Shape: `[384]`, `[768]`, or `[1024]` depending on DinoV2 model
- Examples: `observation.images.top_embedding`, `observation.images.wrist_embedding`
### Using Embeddings in Training
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load dataset with embeddings
dataset = LeRobotDataset("your-username/utokyo_xarm_bimanual_embeddings")
# Access embeddings
item = dataset[0]
task_emb = item["task_embedding"] # Shape: [384]
img_emb = item["observation.images.top_embedding"] # Shape: [768]
# Use in your policy
# Instead of running encoders during training, use pre-computed embeddings
```
## Extending with New Encoders
The script is designed to be easily extensible. To add a new encoder:
### 1. Create Encoder Class
```python
class MyCustomImageEncoder(ImageEncoder):
"""Your custom image encoder."""
def __init__(self, device: str = "cuda"):
super().__init__(device)
# Load your model
self.model = load_my_model()
self.model = self.model.to(self.device)
self.model.eval()
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
# Your encoding logic here
embeddings = []
for img in images:
emb = self.model(img)
embeddings.append(emb)
return np.array(embeddings)
@property
def embedding_dim(self) -> int:
"""Return embedding dimension."""
return 512 # Your embedding dimension
```
### 2. Add to Factory Function
```python
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
# Add your encoder
"my_custom": lambda: MyCustomImageEncoder(device=device),
}
# ... rest of function
```
## Validating Embeddings
After generating embeddings, you can validate them using `validate_embeddings.py`:
```bash
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id pepijn223/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 20
```
@@ -0,0 +1,147 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
from PIL import Image
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ImageEncoder:
"""Base class for image encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
raise NotImplementedError
class DinoV2Encoder(ImageEncoder):
"""DinoV2 image encoder.
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings.
Supports multiple model sizes (ViT-S/14, ViT-B/14, ViT-L/14).
"""
def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cuda", batch_size: int = 32):
super().__init__(device)
self.batch_size = batch_size
self.model_name = model_name
logger.info(f"Loading DinoV2 model: {model_name}")
self.model = torch.hub.load("facebookresearch/dinov2", model_name) # nosec B614
self.model = self.model.to(self.device)
self.model.eval()
# DinoV2 preprocessing
from torchvision import transforms
self.transform = transforms.Compose(
[
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
embeddings = []
with torch.inference_mode():
for i in range(0, len(images), self.batch_size):
batch_images = images[i : i + self.batch_size]
# Convert numpy arrays to PIL Images and apply transforms
pil_images = [Image.fromarray(img.astype(np.uint8)) for img in batch_images]
tensors = torch.stack([self.transform(img) for img in pil_images]).to(self.device)
# Get embeddings
batch_embeddings = self.model(tensors).cpu().numpy()
embeddings.append(batch_embeddings)
return np.concatenate(embeddings, axis=0)
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension based on model size."""
if "vits14" in self.model_name:
return 384 # DinoV2 ViT-S/14
elif "vitb14" in self.model_name:
return 768 # DinoV2 ViT-B/14
elif "vitl14" in self.model_name:
return 1024 # DinoV2 ViT-L/14
else:
return 768 # Default to ViT-B/14
class LanguageEncoder:
"""Base class for language encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
raise NotImplementedError
class MiniLMEncoder(LanguageEncoder):
"""MiniLM language encoder.
MiniLM is a lightweight sentence transformer model that produces high-quality text embeddings.
Supports L6 and L12 model sizes.
"""
def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device: str = "cuda"):
super().__init__(device)
self.model_name = model_name
logger.info(f"Loading MiniLM model: {model_name}")
from transformers import AutoModel, AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
def _mean_pooling(self, model_output, attention_mask):
"""Mean pooling to get sentence embeddings."""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
with torch.inference_mode():
encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
model_output = self.model(**encoded_input)
embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])
return embeddings.cpu().numpy()
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension."""
return 384 # Both MiniLM-L6 and L12 output 384-dim embeddings
@@ -0,0 +1,329 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate embeddings for LeRobot datasets to make them more lightweight and efficient.
This script:
1. Loads a v3.0 LeRobot dataset from the hub
2. Computes embeddings for tasks (language commands) and frames (images)
3. Stores embeddings as new features in the dataset
4. Optionally removes video files to reduce size
5. Pushes the converted dataset to the hub
Current supported encoders:
- Image: DinoV2 (dinov2_vits14, dinov2_vitb14, dinov2_vitl14)
- Language: MiniLM (minilm-l6, minilm-l12)
The architecture is extensible - you can add more encoders by:
1. Creating a new encoder class inheriting from ImageEncoder or LanguageEncoder
2. Implementing the encode() method and embedding_dim property
3. Adding it to the get_image_encoder() or get_language_encoder() factory function
Usage example:
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id lerobot/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
"""
import argparse
import shutil
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import (
DinoV2Encoder,
ImageEncoder,
LanguageEncoder,
MiniLMEncoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
"""Factory function to get image encoder.
To add a new encoder:
1. Create a new class inheriting from ImageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
}
if encoder_name not in encoders:
raise ValueError(f"Unknown image encoder: {encoder_name}. Available options: {list(encoders.keys())}")
return encoders[encoder_name]()
def get_language_encoder(encoder_name: str, device: str = "cuda") -> LanguageEncoder:
"""Factory function to get language encoder.
To add a new encoder:
1. Create a new class inheriting from LanguageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"minilm-l6": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L6-v2", device=device
),
"minilm-l12": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L12-v2", device=device
),
}
if encoder_name not in encoders:
raise ValueError(
f"Unknown language encoder: {encoder_name}. Available options: {list(encoders.keys())}"
)
return encoders[encoder_name]()
def generate_embeddings_for_dataset(
repo_id: str,
output_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
remove_videos: bool = False,
local_dir: Path | None = None,
output_local_dir: Path | None = None,
push_to_hub: bool = False,
):
"""Generate embeddings for a LeRobot dataset.
Args:
repo_id: Source dataset repository ID
output_repo_id: Output dataset repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
remove_videos: Whether to remove video files
local_dir: Local directory for source dataset
output_local_dir: Local directory for output dataset
push_to_hub: Whether to push to hub after conversion
"""
from lerobot.datasets.dataset_tools import modify_features
print(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=True)
print(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
print("Computing task embeddings...")
unique_tasks = dataset.meta.tasks.index.tolist()
task_embeddings = {}
for task in tqdm(unique_tasks, desc="Encoding tasks"):
# Clean up task text
task_clean = task.strip().capitalize().strip(" .,!?-_")
embedding = language_encoder.encode([task_clean])[0]
task_embeddings[task] = embedding
print(f"Computed {len(task_embeddings)} task embeddings")
print("Processing frames and computing embeddings...")
all_task_embeddings = []
all_image_embeddings_dict = {cam_key: [] for cam_key in dataset.meta.camera_keys}
for frame_idx in tqdm(range(dataset.num_frames), desc="Processing frames"):
item = dataset.hf_dataset[frame_idx]
ep_idx = item["episode_index"].item()
task = dataset.meta.tasks.iloc[item["task_index"].item()].name
task_emb = task_embeddings[task]
all_task_embeddings.append(task_emb)
for cam_key in dataset.meta.camera_keys:
if cam_key in dataset.meta.video_keys:
current_ts = item["timestamp"].item()
video_frames = dataset._query_videos({cam_key: [current_ts]}, ep_idx)
img = video_frames[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 4:
img = img[0] # (T, C, H, W) -> (C, H, W)
elif img.ndim != 3:
raise ValueError(f"Unexpected video frame shape {img.shape} for camera {cam_key}")
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
img_np = np.array(img)
else:
img = item[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 3:
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
else:
img_np = np.array(img)
all_image_embeddings_dict[cam_key].append(img_np)
print("Computing image embeddings...")
image_embeddings_dict = {}
for cam_key, images in all_image_embeddings_dict.items():
print(f" {cam_key}: {len(images)} images")
embeddings = image_encoder.encode(images)
image_embeddings_dict[cam_key] = embeddings
all_task_embeddings = np.array(all_task_embeddings)
for cam_key in dataset.meta.camera_keys:
image_embeddings_dict[cam_key] = np.array(image_embeddings_dict[cam_key])
img_emb_dim = image_encoder.embedding_dim
lang_emb_dim = language_encoder.embedding_dim
add_features_dict = {
"task_embedding": (
all_task_embeddings,
{"dtype": "float32", "shape": [lang_emb_dim], "names": None},
),
}
for cam_key in dataset.meta.camera_keys:
add_features_dict[f"{cam_key}_embedding"] = (
image_embeddings_dict[cam_key],
{"dtype": "float32", "shape": [img_emb_dim], "names": None},
)
print("Adding embeddings to dataset...")
remove_features_list = None
if remove_videos:
remove_features_list = dataset.meta.video_keys
output_dataset = modify_features(
dataset=dataset,
add_features=add_features_dict,
remove_features=remove_features_list,
output_dir=output_local_dir,
repo_id=output_repo_id,
)
if remove_videos:
print("Removing video files...")
videos_dir = output_dataset.root / "videos"
if videos_dir.exists():
shutil.rmtree(videos_dir)
print(f"Saved to: {output_dataset.root}")
if push_to_hub:
print(f"Pushing to hub: {output_repo_id}")
output_dataset.push_to_hub(push_videos=not remove_videos)
print("Done!")
def main():
parser = argparse.ArgumentParser(
description="Generate embeddings for LeRobot datasets",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Basic usage with default encoders (DinoV2 ViT-B/14 + MiniLM-L12)
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--push-to-hub
# Generate embeddings and remove videos
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--remove-videos \\
--push-to-hub
Available image encoders:
- dinov2_vits14: DinoV2 ViT-S/14 (384-dim, faster)
- dinov2_vitb14: DinoV2 ViT-B/14 (768-dim, recommended)
- dinov2_vitl14: DinoV2 ViT-L/14 (1024-dim, best quality)
Available language encoders:
- minilm-l6: MiniLM-L6-v2 (384-dim, faster)
- minilm-l12: MiniLM-L12-v2 (384-dim, recommended)
""",
)
parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repository ID")
parser.add_argument("--output-repo-id", type=str, required=True, help="Output dataset repository ID")
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--remove-videos",
action="store_true",
help="Remove video files after generating embeddings",
)
parser.add_argument("--local-dir", type=str, default=None, help="Local directory for source dataset")
parser.add_argument(
"--output-local-dir", type=str, default=None, help="Local directory for output dataset"
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the converted dataset to the hub",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Generate embeddings
generate_embeddings_for_dataset(
repo_id=args.repo_id,
output_repo_id=args.output_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
remove_videos=args.remove_videos,
local_dir=Path(args.local_dir) if args.local_dir else None,
output_local_dir=Path(args.output_local_dir) if args.output_local_dir else None,
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main()
@@ -0,0 +1,222 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Validate pre-computed embeddings against on-the-fly computed embeddings.
Usage:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id <your_username>/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 10
"""
import argparse
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import ImageEncoder, LanguageEncoder
from lerobot.datasets.generating_embeddings.generate_embeddings import (
get_image_encoder,
get_language_encoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between two vectors."""
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
def validate_embeddings(
original_repo_id: str,
embeddings_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
num_samples: int = 10,
device: str = "cuda",
):
"""Validate pre-computed embeddings against on-the-fly embeddings.
Args:
original_repo_id: Original dataset repository ID
embeddings_repo_id: Dataset with pre-computed embeddings repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
num_samples: Number of samples to validate
device: Device to use for encoding
"""
# Load both datasets
print("Loading datasets...")
original_dataset = LeRobotDataset(original_repo_id, download_videos=True)
embeddings_dataset = LeRobotDataset(embeddings_repo_id, download_videos=False)
# Verify both datasets have the same number of frames
assert original_dataset.num_frames == embeddings_dataset.num_frames, (
f"Frame count mismatch: original={original_dataset.num_frames}, "
f"embeddings={embeddings_dataset.num_frames}"
)
camera_keys = original_dataset.meta.camera_keys
# Check embedding features exist
expected_features = ["task_embedding"] + [f"{cam}_embedding" for cam in camera_keys]
for feat in expected_features:
if feat not in embeddings_dataset.features:
raise ValueError(f"Embedding feature not found: {feat}")
# Select random sample indices
sample_indices = np.random.choice(
original_dataset.num_frames, size=min(num_samples, original_dataset.num_frames), replace=False
)
print(f"Validating {len(sample_indices)} samples...")
# Track statistics
task_similarities = []
image_similarities = {cam: [] for cam in camera_keys}
for idx in tqdm(sample_indices, desc="Validating"):
idx = int(idx)
embeddings_item = embeddings_dataset[idx]
precomputed_task_emb = embeddings_item["task_embedding"].numpy()
precomputed_image_embs = {cam: embeddings_item[f"{cam}_embedding"].numpy() for cam in camera_keys}
original_item = original_dataset[idx]
# Get task and compute embedding
task = original_item["task"]
# Clean up task text (same as in generate_embeddings.py)
task_clean = task.strip().capitalize().strip(" .,!?-_")
onthefly_task_emb = language_encoder.encode([task_clean])[0]
# Get images and compute embeddings
onthefly_image_embs = {}
for cam in camera_keys:
img = original_item[cam]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
if img.ndim == 3: # (C, H, W)
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape: {img.shape}")
else:
img_np = np.array(img)
onthefly_image_embs[cam] = image_encoder.encode([img_np])[0]
# Task embedding comparison
task_sim = cosine_similarity(precomputed_task_emb, onthefly_task_emb)
task_similarities.append(task_sim)
# Image embedding comparison
for cam in camera_keys:
img_sim = cosine_similarity(precomputed_image_embs[cam], onthefly_image_embs[cam])
image_similarities[cam].append(img_sim)
# Results
print("\nResults:")
task_sim_threshold = 0.99
img_sim_threshold = 0.99
task_mean_sim = np.mean(task_similarities)
task_pass = task_mean_sim >= task_sim_threshold
print(f" Task: {task_mean_sim:.4f} {'' if task_pass else ''}")
for cam in camera_keys:
cam_mean_sim = np.mean(image_similarities[cam])
cam_pass = cam_mean_sim >= img_sim_threshold
print(f" {cam}: {cam_mean_sim:.4f} {'' if cam_pass else ''}")
image_pass = all(np.mean(image_similarities[cam]) >= img_sim_threshold for cam in camera_keys)
print()
if task_pass and image_pass:
print("✓ PASSED")
else:
print("✗ FAILED")
def main():
parser = argparse.ArgumentParser(
description="Validate and compare pre-computed embeddings with on-the-fly embeddings",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Example:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \\
--original-repo-id lerobot/utokyo_xarm_bimanual \\
--embeddings-repo-id lerobot/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--num-samples 20
""",
)
parser.add_argument("--original-repo-id", type=str, required=True, help="Original dataset repository ID")
parser.add_argument(
"--embeddings-repo-id",
type=str,
required=True,
help="Dataset with pre-computed embeddings repository ID",
)
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--num-samples",
type=int,
default=10,
help="Number of samples to validate (default: 10)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Validate embeddings
validate_embeddings(
original_repo_id=args.original_repo_id,
embeddings_repo_id=args.embeddings_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
num_samples=args.num_samples,
device=args.device,
)
if __name__ == "__main__":
main()
+47 -38
View File
@@ -22,13 +22,11 @@ from pathlib import Path
import datasets
import numpy as np
import os
import packaging.version
import pandas as pd
import PIL.Image
import pyarrow as pa
import pyarrow.parquet as pq
from concurrent.futures import ProcessPoolExecutor
import torch
import torch.utils
from huggingface_hub import HfApi, snapshot_download
@@ -432,9 +430,7 @@ class LeRobotDatasetMetadata:
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(
video_key=video_key, chunk_index=0, file_index=0
)
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_chunk_settings(
@@ -716,6 +712,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
# Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded
# Build a mapping: absolute_index -> relative_index_in_filtered_dataset
self._absolute_to_relative_idx = None
if self.episodes is not None:
self._absolute_to_relative_idx = {
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
}
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -834,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -851,10 +856,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Determine requested episodes
if self.episodes is None:
# Requesting all episodes - check if we have all episodes from metadata
requested_episodes = set(range(self.meta.total_episodes))
else:
# Requesting specific episodes
requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data
@@ -936,7 +939,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_timestamps = {}
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
timestamps = self.hf_dataset[relative_indices]["timestamp"]
else:
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
@@ -944,11 +951,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset[q_idx][key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
"""
Query dataset for indices across keys, skipping video keys.
Tries column-first [key][indices] for speed, falls back to row-first.
Args:
query_indices: Dict mapping keys to index lists to retrieve
Returns:
Dict with stacked tensors of queried data (video keys excluded)
"""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self.meta.video_keys:
continue
# Map absolute indices to relative indices if needed
relative_indices = (
q_idx
if self._absolute_to_relative_idx is None
else [self._absolute_to_relative_idx[idx] for idx in q_idx]
)
try:
result[key] = torch.stack(self.hf_dataset[key][relative_indices])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[relative_indices][key])
return result
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -1151,9 +1179,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
use_batched_encoding = self.batch_encoding_size > 1
if has_video_keys and not use_batched_encoding:
video_paths = self._encode_multiple_temporary_episode_videos(self.meta.video_keys, episode_index)
for (video_key, video_path) in zip(self.meta.video_keys, video_paths):
ep_metadata.update(self._save_episode_video(video_key, episode_index, video_path))
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
# `meta.save_episode` need to be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
@@ -1318,12 +1345,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
return metadata
def _save_episode_video(self, video_key: str, episode_index: int, video_path: str | Path | None = None) -> dict:
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
# Encode episode frames into a temporary video
if video_path is None:
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
else:
ep_path = video_path
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
@@ -1447,22 +1471,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
shutil.rmtree(img_dir)
return temp_path
def _encode_multiple_temporary_episode_videos(self, video_keys, episode_index):
temp_paths = []
img_dirs = []
for video_key in video_keys:
temp_paths.append(Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4")
img_dirs.append(self._get_image_file_dir(episode_index, video_key))
fps = [self.fps]*len(video_keys)
with ProcessPoolExecutor(max_workers=len(video_keys)) as executor:
executor.map(encode_video_frames,img_dirs,temp_paths,fps)
for img_dir in img_dirs:
shutil.rmtree(img_dir)
return temp_paths
@classmethod
def create(
cls,
@@ -1507,6 +1515,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj._absolute_to_relative_idx = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.writer = None
obj.latest_episode = None
+151
View File
@@ -0,0 +1,151 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SARM Temporal Sampler for reward model training.
Samples frames uniformly from episodes for SARM's 9-frame symmetric pattern:
- 1 initial frame + 4 frames before + current + 3 frames after
Boundary handling: clamp to first/last frame when indices go out of bounds.
This enables truly uniform sampling across entire episodes.
"""
import logging
from typing import Iterator, Optional
import numpy as np
import torch
from torch.utils.data import Sampler
import random
class SARMTemporalSampler(Sampler):
"""
Temporal sampler for SARM reward model training with symmetric/bidirectional sampling.
SARM uses 9 frames per sample:
- Frame 0: Initial frame of the episode (always frame 0)
- Frames 1-8: Symmetric context around current frame
Pattern: [t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Early frames: backward indices clamp to 0 (e.g., [0,0,0,5,35,65,95,125])
- Late frames: forward indices clamp to last frame (e.g., [850,880,910,940,970,1000,1000,1000])
This enables truly uniform sampling across entire episodes.
Args:
dataset_from_index: Start indices of episodes (global dataset indices)
dataset_to_index: End indices of episodes (global dataset indices)
frame_gap: Gap between consecutive frames (default: 30 = 1 second at 30fps)
shuffle: Whether to shuffle sampling order
seed: Random seed for reproducibility
samples_per_epoch: Number of samples per epoch (default: 6400)
min_episode_length: Minimum episode length to include (default: 1)
"""
def __init__(
self,
dataset_from_index: np.ndarray,
dataset_to_index: np.ndarray,
frame_gap: int = 30,
shuffle: bool = True,
seed: Optional[int] = None,
samples_per_epoch: int = 6400,
min_episode_length: int = 1,
):
self.dataset_from_index = np.array(dataset_from_index)
self.dataset_to_index = np.array(dataset_to_index)
self.frame_gap = frame_gap
self.shuffle = shuffle
self.samples_per_epoch = samples_per_epoch
self.min_episode_length = min_episode_length
if seed is not None:
self.seed = seed
random.seed(seed)
np.random.seed(seed)
self.generator = torch.Generator().manual_seed(seed)
else:
self.generator = torch.Generator()
# Compute valid episodes and sampling positions (ALL frames for uniform sampling)
self._compute_valid_positions()
logging.info(
f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
f"{len(self.all_valid_positions)} positions (uniform sampling), "
f"{self.samples_per_epoch} samples per epoch, "
f"frame_gap={frame_gap}, symmetric bidirectional pattern"
)
def _compute_valid_positions(self):
"""Compute valid episodes and ALL sampling positions for uniform sampling.
With symmetric bidirectional sampling, we can sample from ANY frame:
- Early frames: backward indices clamp to first frame
- Late frames: forward indices clamp to last frame
"""
self.valid_episodes = []
self.all_valid_positions = []
for ep_idx in range(len(self.dataset_from_index)):
ep_start = self.dataset_from_index[ep_idx]
ep_end = self.dataset_to_index[ep_idx]
episode_length = ep_end - ep_start
# Include all episodes with at least min_episode_length frames
if episode_length >= self.min_episode_length:
self.valid_episodes.append((ep_idx, ep_start, ep_end))
# Include ALL positions in the episode (truly uniform sampling)
for pos in range(ep_start, ep_end):
self.all_valid_positions.append(pos)
self.valid_episodes = np.array(self.valid_episodes)
self.all_valid_positions = np.array(self.all_valid_positions)
if len(self.all_valid_positions) == 0:
raise ValueError(
f"No valid sampling positions found! "
f"Check that episodes have at least {self.min_episode_length} frames."
)
def __len__(self) -> int:
return self.samples_per_epoch
def __iter__(self) -> Iterator[int]:
"""
Yields global dataset indices for uniform sampling across episodes.
Each yielded index represents the "current frame" position.
The dataset's observation_delta_indices then handles loading:
- Frame 0: Episode initial frame (via large negative delta clamping)
- Frames 1-8: Symmetric context around current frame (with boundary clamping)
For early frames: backward indices clamp to first frame (progress ~0%)
For late frames: forward indices clamp to last frame (progress ~100%)
"""
if self.shuffle:
# Randomly sample from all valid positions
for _ in range(self.samples_per_epoch):
idx = np.random.randint(0, len(self.all_valid_positions))
yield int(self.all_valid_positions[idx])
else:
# Sequential sampling with wrap-around
for i in range(self.samples_per_epoch):
idx = i % len(self.all_valid_positions)
yield int(self.all_valid_positions[idx])
+18 -4
View File
@@ -28,6 +28,7 @@ import numpy as np
import packaging.version
import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from datasets import Dataset
@@ -103,7 +104,9 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
return chunk_idx, file_idx
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
def load_nested_dataset(
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
@@ -111,15 +114,26 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars():
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
return datasets
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
+1 -4
View File
@@ -310,7 +310,7 @@ def encode_video_frames(
crf: int | None = 30,
fast_decode: int = 0,
log_level: int | None = av.logging.ERROR,
overwrite: bool = True,
overwrite: bool = False,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
# Check encoder availability
@@ -354,9 +354,6 @@ def encode_video_frames(
if crf is not None:
video_options["crf"] = str(crf)
#TEMPORARY FIX
video_options["preset"] = "12"
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
+57 -9
View File
@@ -21,7 +21,22 @@ import draccus
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import (
ACTION,
LIBERO_KEY_EEF_MAT,
LIBERO_KEY_EEF_POS,
LIBERO_KEY_EEF_QUAT,
LIBERO_KEY_GRIPPER_QPOS,
LIBERO_KEY_GRIPPER_QVEL,
LIBERO_KEY_JOINTS_POS,
LIBERO_KEY_JOINTS_VEL,
LIBERO_KEY_PIXELS_AGENTVIEW,
LIBERO_KEY_PIXELS_EYE_IN_HAND,
OBS_ENV_STATE,
OBS_IMAGE,
OBS_IMAGES,
OBS_STATE,
)
@dataclass
@@ -246,28 +261,61 @@ class LiberoEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features["pixels/agentview_image"] = PolicyFeature(
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
self.features["pixels/agentview_image"] = PolicyFeature(
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(3,),
)
self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(4,),
)
self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(3, 3),
)
self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
else:
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
+69 -5
View File
@@ -14,11 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -32,16 +37,60 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env(
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config.
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
This function creates processor pipelines that transform raw environment
observations and actions. By default, it returns identity processors that do nothing.
For specific environments like LIBERO, it adds environment-specific processing steps.
Args:
cfg (EnvConfig): the config of the environment to instantiate.
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs (currently identity)
"""
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
return preprocessor, postprocessor
def make_env(
cfg: EnvConfig | str,
n_envs: int = 1,
use_async_envs: bool = False,
hub_cache_dir: str | None = None,
trust_remote_code: bool = False,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config or Hub reference.
Args:
cfg (EnvConfig | str): Either an `EnvConfig` object describing the environment to build locally,
or a Hugging Face Hub repository identifier (e.g. `"username/repo"`). In the latter case,
the repo must include a Python file (usually `env.py`).
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
hub_cache_dir (str | None): Optional cache path for downloaded hub files.
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
Default False must be set to True to import/exec hub `env.py`.
Raises:
ValueError: if n_envs < 1
@@ -54,6 +103,21 @@ def make_env(
- For single-task environments: a single suite entry (cfg.type) with task_id=0.
"""
# if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py")
# simplified: only support hub-provided `make_env`
if isinstance(cfg, str):
# _download_hub_file will raise the same RuntimeError if trust_remote_code is False
repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir)
# import and surface clear import errors
module = _import_hub_module(local_file, repo_id)
# call the hub-provided make_env
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
# normalize the return into {suite: {task_id: vec_env}}
return _normalize_hub_result(raw_result)
if n_envs < 1:
raise ValueError("`n_envs` must be at least 1")
+69 -21
View File
@@ -28,7 +28,6 @@ import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.transform_utils import quat2axisangle
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
@@ -175,11 +174,36 @@ class LiberoEnv(gym.Env):
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"agent_pos": spaces.Box(
low=AGENT_POS_LOW,
high=AGENT_POS_HIGH,
shape=(OBS_STATE_DIM,),
dtype=np.float64,
"robot_state": spaces.Dict(
{
"eef": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
"quat": spaces.Box(
low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
),
"mat": spaces.Box(
low=-np.inf, high=np.inf, shape=(3, 3), dtype=np.float64
),
}
),
"gripper": spaces.Dict(
{
"qpos": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
"qvel": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
}
),
"joints": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
}
),
}
),
}
)
@@ -191,6 +215,7 @@ class LiberoEnv(gym.Env):
def render(self):
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
image = image[::-1, ::-1] # flip both H and W for visualization
return image
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
@@ -212,23 +237,48 @@ class LiberoEnv(gym.Env):
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image
state = np.concatenate(
(
raw_obs["robot0_eef_pos"],
quat2axisangle(raw_obs["robot0_eef_quat"]),
raw_obs["robot0_gripper_qpos"],
)
)
agent_pos = state
eef_pos = raw_obs.get("robot0_eef_pos")
eef_quat = raw_obs.get("robot0_eef_quat")
# rotation matrix from controller
eef_mat = self._env.robots[0].controller.ee_ori_mat if eef_pos is not None else None
gripper_qpos = raw_obs.get("robot0_gripper_qpos")
gripper_qvel = raw_obs.get("robot0_gripper_qvel")
joint_pos = raw_obs.get("robot0_joint_pos")
joint_vel = raw_obs.get("robot0_joint_vel")
obs = {
"pixels": images,
"robot_state": {
"eef": {
"pos": eef_pos, # (3,)
"quat": eef_quat, # (4,)
"mat": eef_mat, # (3, 3)
},
"gripper": {
"qpos": gripper_qpos, # (2,)
"qvel": gripper_qvel, # (2,)
},
"joints": {
"pos": joint_pos, # (7,)
"vel": joint_vel, # (7,)
},
},
}
if self.obs_type == "pixels":
return {"pixels": images.copy()}
if self.obs_type == "pixels_agent_pos":
return {
"pixels": images.copy(),
"agent_pos": agent_pos,
}
# Validate required fields are present
if eef_pos is None or eef_quat is None or gripper_qpos is None:
raise ValueError(
f"Missing required robot state fields in raw observation. "
f"Got eef_pos={eef_pos is not None}, eef_quat={eef_quat is not None}, "
f"gripper_qpos={gripper_qpos is not None}"
)
return obs
raise NotImplementedError(
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
@@ -355,12 +405,10 @@ def create_libero_envs(
print(f"Restricting to task_ids={task_ids_filter}")
out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names:
suite = _get_suite(suite_name)
total = len(suite.tasks)
selected = _select_task_ids(total, task_ids_filter)
if not selected:
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
+152 -6
View File
@@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import os
import warnings
from collections.abc import Mapping, Sequence
from functools import singledispatch
@@ -22,14 +24,27 @@ import einops
import gymnasium as gym
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import get_channel_first_image_shape
def _convert_nested_dict(d):
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _convert_nested_dict(v)
elif isinstance(v, np.ndarray):
result[k] = torch.from_numpy(v)
else:
result[k] = v
return result
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
@@ -75,12 +90,14 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[OBS_ENV_STATE] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
if "agent_pos" in observations:
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
if "robot_state" in observations:
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
return return_observations
@@ -195,3 +212,132 @@ def _(envs: Sequence) -> None:
@close_envs.register
def _(env: gym.Env) -> None:
_close_single_env(env)
# helper to safely load a python file as a module
def _load_module_from_path(path: str, module_name: str | None = None):
module_name = module_name or f"hub_env_{os.path.basename(path).replace('.', '_')}"
spec = importlib.util.spec_from_file_location(module_name, path)
if spec is None:
raise ImportError(f"Could not load module spec for {module_name} from {path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
# helper to parse hub string (supports "user/repo", "user/repo@rev", optional path)
# examples:
# "user/repo" -> will look for env.py at repo root
# "user/repo@main:envs/my_env.py" -> explicit revision and path
def _parse_hub_url(hub_uri: str):
# very small parser: [repo_id][@revision][:path]
# repo_id is required (user/repo or org/repo)
revision = None
file_path = "env.py"
if "@" in hub_uri:
repo_and_rev, *rest = hub_uri.split(":", 1)
repo_id, rev = repo_and_rev.split("@", 1)
revision = rev
if rest:
file_path = rest[0]
else:
repo_id, *rest = hub_uri.split(":", 1)
if rest:
file_path = rest[0]
return repo_id, revision, file_path
def _download_hub_file(
cfg_str: str,
trust_remote_code: bool,
hub_cache_dir: str | None,
) -> tuple[str, str, str, str]:
"""
Parse `cfg_str` (hub URL), enforce `trust_remote_code`, and return
(repo_id, file_path, local_file, revision).
"""
if not trust_remote_code:
raise RuntimeError(
f"Refusing to execute remote code from the Hub for '{cfg_str}'. "
"Executing hub env modules runs arbitrary Python code from third-party repositories. "
"If you trust this repo and understand the risks, call `make_env(..., trust_remote_code=True)` "
"and prefer pinning to a specific revision: 'user/repo@<commit-hash>:env.py'."
)
repo_id, revision, file_path = _parse_hub_url(cfg_str)
try:
local_file = hf_hub_download(
repo_id=repo_id, filename=file_path, revision=revision, cache_dir=hub_cache_dir
)
except Exception as e:
# fallback to snapshot download
snapshot_dir = snapshot_download(repo_id=repo_id, revision=revision, cache_dir=hub_cache_dir)
local_file = os.path.join(snapshot_dir, file_path)
if not os.path.exists(local_file):
raise FileNotFoundError(
f"Could not find {file_path} in repository {repo_id}@{revision or 'main'}"
) from e
return repo_id, file_path, local_file, revision
def _import_hub_module(local_file: str, repo_id: str) -> Any:
"""
Import the downloaded file as a module and surface helpful import error messages.
"""
module_name = f"hub_env_{repo_id.replace('/', '_')}"
try:
module = _load_module_from_path(local_file, module_name=module_name)
except ModuleNotFoundError as e:
missing = getattr(e, "name", None) or str(e)
raise ModuleNotFoundError(
f"Hub env '{repo_id}:{os.path.basename(local_file)}' failed to import because the dependency "
f"'{missing}' is not installed locally.\n\n"
) from e
except ImportError as e:
raise ImportError(
f"Failed to load hub env module '{repo_id}:{os.path.basename(local_file)}'. Import error: {e}\n\n"
) from e
return module
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
"""
Ensure module exposes make_env and call it.
"""
if not hasattr(module, "make_env"):
raise AttributeError(
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
)
entry_fn = module.make_env
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""
Normalize possible return types from hub `make_env` into the mapping:
{ suite_name: { task_id: vector_env } }
Accepts:
- dict (assumed already correct)
- gym.vector.VectorEnv
- gym.Env (will be wrapped into SyncVectorEnv)
"""
if isinstance(result, dict):
return result
# VectorEnv: use its spec.id if available
if isinstance(result, gym.vector.VectorEnv):
suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
return {suite_name: {0: result}}
# Single Env: wrap into SyncVectorEnv
if isinstance(result, gym.Env):
vec = gym.vector.SyncVectorEnv([lambda: result])
suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
return {suite_name: {0: vec}}
raise ValueError(
"Hub `make_env` must return either a mapping {suite: {task_id: vec_env}}, "
"a gym.vector.VectorEnv, or a single gym.Env."
)
+1 -8
View File
@@ -14,11 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .motors_bus import (
Motor,
MotorCalibration,
MotorNormMode,
MotorsBus, # Backward compatibility (alias for SerialMotorsBus)
MotorsBusBase,
SerialMotorsBus,
)
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
-18
View File
@@ -1,18 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .damiao import DamiaoMotorsBus
from .tables import *
-905
View File
@@ -1,905 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(pepijn): add license of: https://github.com/cmjang/DM_Control_Python?tab=MIT-1-ov-file#readme
import logging
import time
from contextlib import contextmanager
from copy import deepcopy
from functools import cached_property
from typing import Dict, List, Optional, Tuple, Union
import can
import numpy as np
from lerobot.motors import Motor, MotorCalibration, MotorNormMode, MotorsBusBase
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.utils import enter_pressed, move_cursor_up
from .tables import (
AVAILABLE_BAUDRATES,
CAN_CMD_DISABLE,
CAN_CMD_ENABLE,
CAN_CMD_REFRESH,
CAN_CMD_SET_ZERO,
CAN_PARAM_ID,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
MODEL_RESOLUTION,
MOTOR_LIMIT_PARAMS,
NORMALIZED_DATA,
MotorType,
)
logger = logging.getLogger(__name__)
NameOrID = Union[str, int]
Value = Union[int, float]
class DamiaoMotorsBus(MotorsBusBase):
"""
The Damiao implementation for a MotorsBus using CAN bus communication.
This class uses python-can for CAN bus communication with Damiao motors.
For more info, see:
- python-can documentation: https://python-can.readthedocs.io/en/stable/
- Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/
- DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python
"""
# CAN-specific settings
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
default_baudrate = DEFAULT_BAUDRATE
default_timeout = DEFAULT_TIMEOUT_MS
# Motor configuration
model_resolution_table = deepcopy(MODEL_RESOLUTION)
normalized_data = deepcopy(NORMALIZED_DATA)
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
can_interface: str = "auto",
use_can_fd: bool = True,
bitrate: int = 1000000,
data_bitrate: int | None = 5000000,
):
"""
Initialize the Damiao motors bus.
Args:
port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS)
motors: Dictionary mapping motor names to Motor objects
calibration: Optional calibration data
can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial)
use_can_fd: Whether to use CAN FD mode (default: True for OpenArms)
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
self.use_can_fd = use_can_fd
self.bitrate = bitrate
self.data_bitrate = data_bitrate
self.canbus = None
self._is_connected = False
# Map motor names to CAN IDs
self._motor_can_ids = {}
self._recv_id_to_motor = {}
# Store motor types and recv IDs
self._motor_types = {}
for name, motor in self.motors.items():
if hasattr(motor, "motor_type"):
self._motor_types[name] = motor.motor_type
else:
# Default to DM4310 if not specified
self._motor_types[name] = MotorType.DM4310
# Map recv_id to motor name for filtering responses
if hasattr(motor, "recv_id"):
self._recv_id_to_motor[motor.recv_id] = name
@property
def is_connected(self) -> bool:
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
Args:
handshake: If True, ping all motors to verify they're present
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected."
)
try:
# Auto-detect interface type based on port name
if self.can_interface == "auto":
if self.port.startswith("/dev/"):
# Serial device (macOS/Windows)
self.can_interface = "slcan"
logger.info(f"Auto-detected slcan interface for port {self.port}")
else:
# Network interface (Linux)
self.can_interface = "socketcan"
logger.info(f"Auto-detected socketcan interface for port {self.port}")
# Connect to CAN bus
if self.can_interface == "socketcan":
# Linux SocketCAN with CAN FD support
if self.use_can_fd and self.data_bitrate is not None:
self.canbus = can.interface.Bus(
channel=self.port,
interface="socketcan",
bitrate=self.bitrate,
data_bitrate=self.data_bitrate,
fd=True
)
logger.info(f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})")
else:
self.canbus = can.interface.Bus(
channel=self.port,
interface="socketcan",
bitrate=self.bitrate
)
logger.info(f"Connected to {self.port} with CAN 2.0 (bitrate={self.bitrate})")
elif self.can_interface == "slcan":
# Serial Line CAN (macOS, Windows, or USB adapters)
# Note: SLCAN typically doesn't support CAN FD
self.canbus = can.interface.Bus(
channel=self.port,
interface="slcan",
bitrate=self.bitrate
)
logger.info(f"Connected to {self.port} with SLCAN (bitrate={self.bitrate})")
else:
# Generic interface (vector, pcan, etc.)
if self.use_can_fd and self.data_bitrate is not None:
self.canbus = can.interface.Bus(
channel=self.port,
interface=self.can_interface,
bitrate=self.bitrate,
data_bitrate=self.data_bitrate,
fd=True
)
else:
self.canbus = can.interface.Bus(
channel=self.port,
interface=self.can_interface,
bitrate=self.bitrate
)
self._is_connected = True
if handshake:
self._handshake()
logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
except Exception as e:
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}")
def _handshake(self) -> None:
"""Verify all motors are present by refreshing their status."""
for motor_name in self.motors:
self._refresh_motor(motor_name)
time.sleep(0.01) # Small delay between motors
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected."
)
if disable_torque:
try:
self.disable_torque()
except Exception as e:
logger.warning(f"Failed to disable torque during disconnect: {e}")
if self.canbus:
self.canbus.shutdown()
self.canbus = None
self._is_connected = False
logger.debug(f"{self.__class__.__name__} disconnected.")
def configure_motors(self) -> None:
"""Configure all motors with default settings."""
# Damiao motors don't require much configuration in MIT mode
# Just ensure they're enabled
for motor in self.motors:
self._enable_motor(motor)
time.sleep(0.01)
def _enable_motor(self, motor: NameOrID) -> None:
"""Enable a single motor."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [CAN_CMD_ENABLE]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
self._recv_motor_response(expected_recv_id=recv_id)
def _disable_motor(self, motor: NameOrID) -> None:
"""Disable a single motor."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [CAN_CMD_DISABLE]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
self._recv_motor_response(expected_recv_id=recv_id)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
motors = self._get_motors_list(motors)
for motor in motors:
for _ in range(num_retry + 1):
try:
self._enable_motor(motor)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(0.01)
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
motors = self._get_motors_list(motors)
for motor in motors:
for _ in range(num_retry + 1):
try:
self._disable_motor(motor)
break
except Exception as e:
if _ == num_retry:
raise e
time.sleep(0.01)
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
"""
Context manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
Examples:
>>> with bus.torque_disabled():
... # Safe operations here with torque disabled
... pass
"""
self.disable_torque(motors)
try:
yield
finally:
self.enable_torque(motors)
def set_zero_position(self, motors: str | list[str] | None = None) -> None:
"""Set current position as zero for selected motors."""
motors = self._get_motors_list(motors)
for motor in motors:
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [CAN_CMD_SET_ZERO]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
self._recv_motor_response(expected_recv_id=recv_id)
time.sleep(0.01)
def _refresh_motor(self, motor: NameOrID) -> Optional[can.Message]:
"""Refresh motor status and return the response."""
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
def _recv_motor_response(self, expected_recv_id: Optional[int] = None, timeout: float = 0.001) -> Optional[can.Message]:
"""
Receive a response from a motor.
Args:
expected_recv_id: If provided, only return messages from this CAN ID
timeout: Timeout in seconds (default: 1ms for high-speed operation)
Returns:
CAN message if received, None otherwise
"""
try:
start_time = time.time()
messages_seen = []
while time.time() - start_time < timeout:
msg = self.canbus.recv(timeout=0.0001) # 100us timeout for fast polling
if msg:
messages_seen.append(f"0x{msg.arbitration_id:02X}")
# If no filter specified, return any message
if expected_recv_id is None:
return msg
# Otherwise, only return if it matches the expected recv_id
if msg.arbitration_id == expected_recv_id:
return msg
else:
logger.debug(f"Ignoring message from CAN ID 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}")
# Only log warnings if we're in debug mode to reduce overhead
if logger.isEnabledFor(logging.DEBUG):
if messages_seen:
logger.debug(f"Received {len(messages_seen)} message(s) from IDs {set(messages_seen)}, but expected 0x{expected_recv_id:02X}")
else:
logger.debug(f"No CAN messages received (expected from 0x{expected_recv_id:02X})")
except Exception as e:
logger.debug(f"Failed to receive CAN message: {e}")
return None
def _recv_all_responses(self, expected_recv_ids: list[int], timeout: float = 0.002) -> dict[int, can.Message]:
"""
Efficiently receive responses from multiple motors at once.
Uses the OpenArms pattern: collect all available messages within timeout.
Args:
expected_recv_ids: List of CAN IDs we expect responses from
timeout: Total timeout in seconds (default: 2ms)
Returns:
Dictionary mapping recv_id to CAN message
"""
responses = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
msg = self.canbus.recv(timeout=0.0002) # 200us poll timeout (increased from 100us for better reliability)
if msg and msg.arbitration_id in expected_set:
responses[msg.arbitration_id] = msg
if len(responses) == len(expected_recv_ids):
break # Got all responses, exit early
except Exception as e:
logger.debug(f"Error receiving responses: {e}")
return responses
def _mit_control(
self,
motor: NameOrID,
kp: float,
kd: float,
position_degrees: float,
velocity_deg_per_sec: float,
torque: float,
) -> None:
"""
Send MIT control command to a motor.
Args:
motor: Motor name or ID
kp: Position gain
kd: Velocity gain
position_degrees: Target position (degrees)
velocity_deg_per_sec: Target velocity (degrees/s)
torque: Target torque (N·m)
"""
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
# Convert degrees to radians for motor control
position_rad = np.radians(position_degrees)
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Encode parameters
kp_uint = self._float_to_uint(kp, 0, 500, 12)
kd_uint = self._float_to_uint(kd, 0, 5, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
recv_id = self._get_motor_recv_id(motor)
self._recv_motor_response(expected_recv_id=recv_id)
def _mit_control_batch(
self,
commands: Dict[NameOrID, Tuple[float, float, float, float, float]],
) -> None:
"""
Send MIT control commands to multiple motors in batch (optimized).
Sends all commands first, then collects responses. Much faster than sequential.
Args:
commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque)
Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...}
"""
if not commands:
return
expected_recv_ids = []
# Step 1: Send all MIT control commands (no waiting)
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
# Convert degrees to radians
position_rad = np.radians(position_degrees)
velocity_rad_per_sec = np.radians(velocity_deg_per_sec)
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Encode parameters
kp_uint = self._float_to_uint(kp, 0, 500, 12)
kd_uint = self._float_to_uint(kd, 0, 5, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12)
tau_uint = self._float_to_uint(torque, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
# Send command
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
# Track expected response
recv_id = self._get_motor_recv_id(motor)
expected_recv_ids.append(recv_id)
# Step 2: Collect all responses at once
self._recv_all_responses(expected_recv_ids, timeout=0.002)
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
x = max(x_min, min(x_max, x)) # Clamp to range
span = x_max - x_min
data_norm = (x - x_min) / span
return int(data_norm * ((1 << bits) - 1))
def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float:
"""Convert unsigned integer from CAN to float."""
span = x_max - x_min
data_norm = float(x) / ((1 << bits) - 1)
return data_norm * span + x_min
def _decode_motor_state(self, data: bytes, motor_type: MotorType) -> Tuple[float, float, float, int, int]:
"""
Decode motor state from CAN data.
Returns:
Tuple of (position_degrees, velocity_deg_per_sec, torque, temp_mos, temp_rotor)
"""
if len(data) < 8:
raise ValueError("Invalid motor state data")
# Extract encoded values
q_uint = (data[1] << 8) | data[2]
dq_uint = (data[3] << 4) | (data[4] >> 4)
tau_uint = ((data[4] & 0x0F) << 8) | data[5]
t_mos = data[6]
t_rotor = data[7]
# Get motor limits
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
# Decode to physical values (radians)
position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16)
velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12)
torque = self._uint_to_float(tau_uint, -tmax, tmax, 12)
# Convert to degrees
position_degrees = np.degrees(position_rad)
velocity_deg_per_sec = np.degrees(velocity_rad_per_sec)
return position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor
def read(
self,
data_name: str,
motor: str,
*,
normalize: bool = True,
num_retry: int = 0,
) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
if msg is None:
motor_id = self._get_motor_id(motor)
recv_id = self._get_motor_recv_id(motor)
raise ConnectionError(
f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). "
f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, "
f"3) Motor IDs are configured correctly using Damiao Debugging Tools"
)
motor_type = self._motor_types.get(motor, MotorType.DM4310)
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
# Return requested data (already in degrees for position/velocity)
if data_name == "Present_Position":
value = position_degrees
elif data_name == "Present_Velocity":
value = velocity_deg_per_sec
elif data_name == "Present_Torque":
value = torque
elif data_name == "Temperature_MOS":
value = t_mos
elif data_name == "Temperature_Rotor":
value = t_rotor
else:
raise ValueError(f"Unknown data_name: {data_name}")
# For Damiao, positions are always in degrees, no normalization needed
# We keep the normalize parameter for compatibility but don't use it
return value
def write(
self,
data_name: str,
motor: str,
value: Value,
*,
normalize: bool = True,
num_retry: int = 0,
) -> None:
"""Write a value to a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Value is expected to be in degrees for positions
if data_name == "Goal_Position":
# Use MIT control with position in degrees
self._mit_control(motor, 10.0, 0.5, value, 0, 0)
else:
raise ValueError(f"Writing {data_name} not supported in MIT mode")
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
) -> Dict[str, Value]:
"""
Read the same value from multiple motors simultaneously.
Uses batched operations: sends all refresh commands, then collects all responses.
This is MUCH faster than sequential reads (OpenArms pattern).
"""
motors = self._get_motors_list(motors)
result = {}
# Step 1: Send refresh commands to ALL motors first (no waiting)
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
# Step 2: Collect all responses at once (batch receive)
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=0.01) # 10ms total timeout
# Step 3: Parse responses
for motor in motors:
try:
recv_id = self._get_motor_recv_id(motor)
msg = responses.get(recv_id)
if msg is None:
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
result[motor] = 0.0
continue
motor_type = self._motor_types.get(motor, MotorType.DM4310)
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
# Return requested data
if data_name == "Present_Position":
value = position_degrees
elif data_name == "Present_Velocity":
value = velocity_deg_per_sec
elif data_name == "Present_Torque":
value = torque
elif data_name == "Temperature_MOS":
value = t_mos
elif data_name == "Temperature_Rotor":
value = t_rotor
else:
raise ValueError(f"Unknown data_name: {data_name}")
result[motor] = value
except Exception as e:
logger.warning(f"Failed to read {data_name} from {motor}: {e}")
result[motor] = 0.0
return result
def sync_read_all_states(
self,
motors: str | list[str] | None = None,
*,
num_retry: int = 0,
) -> Dict[str, Dict[str, Value]]:
"""
Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle.
This is 3x faster than calling sync_read() three times separately.
Returns:
Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque'
Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...}
"""
motors = self._get_motors_list(motors)
result = {}
# Step 1: Send refresh commands to ALL motors first (with small delays to reduce bus congestion)
for motor in motors:
motor_id = self._get_motor_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
self.canbus.send(msg)
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
# Step 2: Collect all responses at once (batch receive)
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
# Step 3: Parse responses and extract ALL state values
for motor in motors:
try:
recv_id = self._get_motor_recv_id(motor)
msg = responses.get(recv_id)
if msg is None:
logger.warning(f"No response from motor '{motor}' (recv ID: 0x{recv_id:02X})")
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
continue
motor_type = self._motor_types.get(motor, MotorType.DM4310)
position_degrees, velocity_deg_per_sec, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type)
# Return all state values in one dict
result[motor] = {
"position": position_degrees,
"velocity": velocity_deg_per_sec,
"torque": torque,
"temp_mos": t_mos,
"temp_rotor": t_rotor,
}
except Exception as e:
logger.warning(f"Failed to read state from {motor}: {e}")
result[motor] = {"position": 0.0, "velocity": 0.0, "torque": 0.0}
return result
def sync_write(
self,
data_name: str,
values: Dict[str, Value],
*,
normalize: bool = True,
num_retry: int = 0,
) -> None:
"""
Write different values to multiple motors simultaneously. Positions are always in degrees.
Uses batched operations: sends all commands first, then collects responses (OpenArms pattern).
"""
if data_name == "Goal_Position":
# Step 1: Send all MIT control commands first (no waiting)
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types.get(motor_name, MotorType.DM4310)
# Convert degrees to radians
position_rad = np.radians(value_degrees)
# Default gains for position control
kp, kd = 10.0, 0.5
# Get motor limits and encode parameters
pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type]
kp_uint = self._float_to_uint(kp, 0, 500, 12)
kd_uint = self._float_to_uint(kd, 0, 5, 12)
q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16)
dq_uint = self._float_to_uint(0, -vmax, vmax, 12)
tau_uint = self._float_to_uint(0, -tmax, tmax, 12)
# Pack data
data = [0] * 8
data[0] = (q_uint >> 8) & 0xFF
data[1] = q_uint & 0xFF
data[2] = dq_uint >> 4
data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF)
data[4] = kp_uint & 0xFF
data[5] = kd_uint >> 4
data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF)
data[7] = tau_uint & 0xFF
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self.canbus.send(msg)
time.sleep(0.0001) # 100us delay between commands to reduce bus congestion
# Step 2: Collect all responses at once
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in values.keys()]
self._recv_all_responses(expected_recv_ids, timeout=0.015) # 15ms timeout (increased for reliability)
else:
# Fall back to individual writes for other data types
for motor, value in values.items():
self.write(data_name, motor, value, normalize=normalize, num_retry=num_retry)
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration data from motors."""
# Damiao motors don't store calibration internally
# Return existing calibration or empty dict
return self.calibration if self.calibration else {}
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration data to motors."""
# Damiao motors don't store calibration internally
# Just cache it in memory
if cache:
self.calibration = calibration_dict
def record_ranges_of_motion(
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""
Interactively record the min/max values of each motor in degrees.
Move the joints by hand (with torque disabled) while the method streams live positions.
Press Enter to finish.
"""
if motors is None:
motors = list(self.motors.keys())
elif isinstance(motors, (str, int)):
motors = [motors]
# Disable torque for manual movement
self.disable_torque(motors)
time.sleep(0.1)
# Get initial positions (already in degrees)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
print("\nMove joints through their full range of motion. Press ENTER when done.")
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", motors, normalize=False)
for motor in motors:
if motor in positions:
mins[motor] = min(positions[motor], mins.get(motor, positions[motor]))
maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor]))
if display_values:
print("\n" + "=" * 50)
print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}")
print("-" * 50)
for motor in motors:
if motor in positions:
print(f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}")
if enter_pressed():
user_pressed_enter = True
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motors) + 4)
time.sleep(0.05)
# Re-enable torque
self.enable_torque(motors)
# Validate ranges
for motor in motors:
if motor in mins and motor in maxes:
if abs(maxes[motor] - mins[motor]) < 5.0: # At least 5 degrees of range
raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)")
return mins, maxes
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
"""Convert motor specification to list of motor names."""
if motors is None:
return list(self.motors.keys())
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors
else:
raise TypeError(f"Invalid motors type: {type(motors)}")
def _get_motor_id(self, motor: NameOrID) -> int:
"""Get CAN ID for a motor."""
if isinstance(motor, str):
if motor in self.motors:
return self.motors[motor].id
else:
raise ValueError(f"Unknown motor: {motor}")
else:
return motor
def _get_motor_name(self, motor: NameOrID) -> str:
"""Get motor name from name or ID."""
if isinstance(motor, str):
return motor
else:
for name, m in self.motors.items():
if m.id == motor:
return name
raise ValueError(f"Unknown motor ID: {motor}")
def _get_motor_recv_id(self, motor: NameOrID) -> Optional[int]:
"""Get motor recv_id from name or ID."""
motor_name = self._get_motor_name(motor)
motor_obj = self.motors.get(motor_name)
if motor_obj and hasattr(motor_obj, "recv_id"):
return motor_obj.recv_id
return None
@cached_property
def is_calibrated(self) -> bool:
"""Check if motors are calibrated."""
return bool(self.calibration)
-209
View File
@@ -1,209 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration tables for Damiao motors."""
from enum import IntEnum
from typing import Dict, List, Tuple
# Motor type definitions
class MotorType(IntEnum):
DM3507 = 0
DM4310 = 1
DM4310_48V = 2
DM4340 = 3
DM4340_48V = 4
DM6006 = 5
DM8006 = 6
DM8009 = 7
DM10010L = 8
DM10010 = 9
DMH3510 = 10
DMH6215 = 11
DMG6220 = 12
# Control modes
class ControlMode(IntEnum):
MIT = 1
POS_VEL = 2
VEL = 3
TORQUE_POS = 4
# Motor variable IDs (RID)
class MotorVariable(IntEnum):
UV_VALUE = 0
KT_VALUE = 1
OT_VALUE = 2
OC_VALUE = 3
ACC = 4
DEC = 5
MAX_SPD = 6
MST_ID = 7
ESC_ID = 8
TIMEOUT = 9
CTRL_MODE = 10
DAMP = 11
INERTIA = 12
HW_VER = 13
SW_VER = 14
SN = 15
NPP = 16
RS = 17
LS = 18
FLUX = 19
GR = 20
PMAX = 21
VMAX = 22
TMAX = 23
I_BW = 24
KP_ASR = 25
KI_ASR = 26
KP_APR = 27
KI_APR = 28
OV_VALUE = 29
GREF = 30
DETA = 31
V_BW = 32
IQ_C1 = 33
VL_C1 = 34
CAN_BR = 35
SUB_VER = 36
U_OFF = 50
V_OFF = 51
K1 = 52
K2 = 53
M_OFF = 54
DIR = 55
P_M = 80
XOUT = 81
# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
MOTOR_LIMIT_PARAMS = {
MotorType.DM3507: (12.5, 30, 10),
MotorType.DM4310: (12.5, 30, 10),
MotorType.DM4310_48V: (12.5, 50, 10),
MotorType.DM4340: (12.5, 8, 28),
MotorType.DM4340_48V: (12.5, 10, 28),
MotorType.DM6006: (12.5, 45, 20),
MotorType.DM8006: (12.5, 45, 40),
MotorType.DM8009: (12.5, 45, 54),
MotorType.DM10010L: (12.5, 25, 200),
MotorType.DM10010: (12.5, 20, 200),
MotorType.DMH3510: (12.5, 280, 1),
MotorType.DMH6215: (12.5, 45, 10),
MotorType.DMG6220: (12.5, 45, 10),
}
# Motor model names
MODEL_NAMES = {
MotorType.DM3507: "dm3507",
MotorType.DM4310: "dm4310",
MotorType.DM4310_48V: "dm4310_48v",
MotorType.DM4340: "dm4340",
MotorType.DM4340_48V: "dm4340_48v",
MotorType.DM6006: "dm6006",
MotorType.DM8006: "dm8006",
MotorType.DM8009: "dm8009",
MotorType.DM10010L: "dm10010l",
MotorType.DM10010: "dm10010",
MotorType.DMH3510: "dmh3510",
MotorType.DMH6215: "dmh6215",
MotorType.DMG6220: "dmg6220",
}
# Motor resolution table (encoder counts per revolution)
MODEL_RESOLUTION = {
"dm3507": 65536,
"dm4310": 65536,
"dm4310_48v": 65536,
"dm4340": 65536,
"dm4340_48v": 65536,
"dm6006": 65536,
"dm8006": 65536,
"dm8009": 65536,
"dm10010l": 65536,
"dm10010": 65536,
"dmh3510": 65536,
"dmh6215": 65536,
"dmg6220": 65536,
}
# CAN baudrates supported by Damiao motors
AVAILABLE_BAUDRATES = [
125000, # 0: 125 kbps
200000, # 1: 200 kbps
250000, # 2: 250 kbps
500000, # 3: 500 kbps
1000000, # 4: 1 mbps (default for OpenArms)
2000000, # 5: 2 mbps
2500000, # 6: 2.5 mbps
3200000, # 7: 3.2 mbps
4000000, # 8: 4 mbps
5000000, # 9: 5 mbps
]
DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms
# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 1000
# Data that should be normalized
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
# OpenArms specific configurations
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms has 7 DOF per arm (14 total for dual arm)
OPENARMS_ARM_MOTOR_IDS = {
"joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan
"joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift
"joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex
"joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex
"joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll
"joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch
"joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation
}
OPENARMS_GRIPPER_MOTOR_IDS = {
"gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper
}
# Default motor types for OpenArms
OPENARMS_DEFAULT_MOTOR_TYPES = {
"joint_1": MotorType.DM8009, # Shoulder pan - high torque
"joint_2": MotorType.DM8009, # Shoulder lift - high torque
"joint_3": MotorType.DM4340, # Shoulder rotation
"joint_4": MotorType.DM4340, # Elbow flex
"joint_5": MotorType.DM4310, # Wrist roll
"joint_6": MotorType.DM4310, # Wrist pitch
"joint_7": MotorType.DM4310, # Wrist rotation
"gripper": MotorType.DM4310, # Gripper
}
# MIT control parameter ranges
MIT_KP_RANGE = (0.0, 500.0)
MIT_KD_RANGE = (0.0, 5.0)
# CAN frame command IDs
CAN_CMD_ENABLE = 0xFC
CAN_CMD_DISABLE = 0xFD
CAN_CMD_SET_ZERO = 0xFE
CAN_CMD_REFRESH = 0xCC
CAN_CMD_QUERY_PARAM = 0x33
CAN_CMD_WRITE_PARAM = 0x55
CAN_CMD_SAVE_PARAM = 0xAA
# CAN ID for parameter operations
CAN_PARAM_ID = 0x7FF
+2 -2
View File
@@ -24,7 +24,7 @@ from enum import Enum
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
AVAILABLE_BAUDRATES,
MODEL_BAUDRATE_TABLE,
@@ -100,7 +100,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]:
return data
class DynamixelMotorsBus(SerialMotorsBus):
class DynamixelMotorsBus(MotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
the motors. For more info, see the Dynamixel SDK Documentation:
+3 -3
View File
@@ -19,7 +19,7 @@ from pprint import pformat
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
FIRMWARE_MAJOR_VERSION,
FIRMWARE_MINOR_VERSION,
@@ -96,7 +96,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
class FeetechMotorsBus(SerialMotorsBus):
class FeetechMotorsBus(MotorsBus):
"""
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
@@ -165,7 +165,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _handshake(self) -> None:
self._assert_motors_exist()
#self._assert_same_firmware()
self._assert_same_firmware()
def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
if self.protocol_version == 0:
+4 -96
View File
@@ -19,8 +19,6 @@
# TODO(aliberts): Add block noqa when feature below is available
# https://github.com/astral-sh/ruff/issues/3711
from __future__ import annotations
import abc
import logging
from contextlib import contextmanager
@@ -43,92 +41,6 @@ Value: TypeAlias = int | float
logger = logging.getLogger(__name__)
class MotorsBusBase(abc.ABC):
"""
Base class for all motor bus implementations.
This is a minimal interface that all motor buses must implement, regardless of their
communication protocol (serial, CAN, etc.).
"""
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
@abc.abstractmethod
def connect(self, handshake: bool = True) -> None:
"""Establish connection to the motors."""
pass
@abc.abstractmethod
def disconnect(self, disable_torque: bool = True) -> None:
"""Disconnect from the motors."""
pass
@property
@abc.abstractmethod
def is_connected(self) -> bool:
"""Check if connected to the motors."""
pass
@abc.abstractmethod
def read(self, data_name: str, motor: str, *, normalize: bool = True, num_retry: int = 0) -> Value:
"""Read a value from a single motor."""
pass
@abc.abstractmethod
def write(
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
) -> None:
"""Write a value to a single motor."""
pass
@abc.abstractmethod
def sync_read(
self, data_name: str, motors: str | list[str] | None = None, *, normalize: bool = True
) -> dict[str, Value]:
"""Read a value from multiple motors."""
pass
@abc.abstractmethod
def sync_write(
self,
data_name: str,
values: Value | dict[str, Value],
motors: str | list[str] | None = None,
*,
normalize: bool = True,
) -> None:
"""Write values to multiple motors."""
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors."""
pass
@abc.abstractmethod
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
"""Disable torque on selected motors."""
pass
@abc.abstractmethod
def read_calibration(self) -> dict[str, MotorCalibration]:
"""Read calibration parameters from the motors."""
pass
@abc.abstractmethod
def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
"""Write calibration parameters to the motors."""
pass
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
ctrl_table = model_ctrl_table.get(model)
if ctrl_table is None:
@@ -291,15 +203,15 @@ class GroupSyncWrite(Protocol):
def txPacket(self): ...
class SerialMotorsBus(MotorsBusBase):
class MotorsBus(abc.ABC):
"""
A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication.
A MotorsBus allows to efficiently read and write to the attached motors.
It represents several motors daisy-chained together and connected through a serial port.
There are currently two implementations of this class:
There are currently two implementations of this abstract class:
- DynamixelMotorsBus
- FeetechMotorsBus
This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.).
Note: This class may evolve in the future should we add support for other types of bus.
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
@@ -1300,7 +1212,3 @@ class SerialMotorsBus(MotorsBusBase):
for id_, value in ids_values.items():
data = self._serialize_data(value, length)
self.sync_writer.addParam(id_, data)
# Backward compatibility alias
MotorsBus = SerialMotorsBus
+24 -16
View File
@@ -35,9 +35,11 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
@@ -102,6 +104,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "sarm":
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot":
from lerobot.policies.groot.modeling_groot import GrootPolicy
@@ -321,6 +327,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SARMConfig):
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
processors = make_sarm_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, GrootConfig):
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
@@ -404,6 +418,13 @@ def make_policy(
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
if ds_meta is not None and hasattr(ds_meta, 'stats'):
kwargs["dataset_stats"] = ds_meta.stats
if ds_meta is not None:
kwargs["dataset_meta"] = ds_meta
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
@@ -420,20 +441,7 @@ def make_policy(
# policy = torch.compile(policy, mode="reduce-overhead")
if not rename_map:
expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
provided_features = set(features.keys())
if expected_features and provided_features != expected_features:
missing = expected_features - provided_features
extra = provided_features - expected_features
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
raise ValueError(
f"Feature mismatch between dataset/environment and policy config.\n"
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
f"Please ensure your dataset and policy use consistent feature names.\n"
f"If your dataset uses different observation keys (e.g., cameras named differently), "
f"use the `--rename_map` argument, for example:\n"
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
f'"observation.images.top": "observation.images.camera2"}}\''
)
validate_visual_features_consistency(cfg, features)
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
return policy
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -47,6 +48,9 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
+83 -15
View File
@@ -19,11 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -51,6 +53,12 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel(
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI0 PyTorch model."""
def __init__(self, config: PI0Config):
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
@@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
@@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI0 model
self.model = PI0Pytorch(config)
self.init_rtc_processor()
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy):
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
state = self.prepare_state(batch)
# Sample actions using the model
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
# Sample actions using the model (pass through RTC kwargs)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
@PreTrainedConfig.register_subclass("pi05")
@@ -46,6 +47,9 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
+83 -14
View File
@@ -19,11 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -50,6 +52,12 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel(
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI05 PyTorch model."""
def __init__(self, config: PI05Config):
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
def sample_actions(
self,
images,
img_masks,
tokens,
masks,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
num_steps = self.config.num_inference_steps
@@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
@@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI05 model
self.model = PI05Pytorch(config)
self.init_rtc_processor()
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Sample actions using the model (no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks)
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]
+38
View File
@@ -0,0 +1,38 @@
# Real-Time Chunking (RTC)
This module contains the LeRobot implementation of **Real-Time Chunking (RTC)**, an inference-time technique for flow-matching based policies.
**Note**: RTC is not a policy itself, but rather an inference enhancement that works with flow-matching based policies including [π₀](../pi0/), [π₀.₅](../pi05/), and [SmolVLA](../smolvla/).
---
## Citation
If you use Real-Time Chunking in your work, please cite:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{black2025realtimeexecutionactionchunking,
title={Real-Time Execution of Action Chunking Flow Policies},
author={Kevin Black and Manuel Y. Galliker and Sergey Levine},
year={2025},
eprint={2506.07339},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2506.07339},
}
```
---
## License
This implementation follows the **Apache 2.0 License**, consistent with the LeRobot project.
+219
View File
@@ -0,0 +1,219 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action queue management for Real-Time Chunking (RTC).
This module provides ActionQueue, a thread-safe queue for managing action chunks
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
handling action merging and leftover tracking.
"""
import logging
from threading import Lock
import torch
from torch import Tensor
from lerobot.policies.rtc.configuration_rtc import RTCConfig
logger = logging.getLogger(__name__)
class ActionQueue:
"""Thread-safe queue for managing action chunks in real-time control.
This queue handles two types of action sequences:
- Original actions: Used for RTC to compute leftovers from previous chunks
- Processed actions: Post-processed actions ready for robot execution
The queue operates in two modes:
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
Args:
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
Attributes:
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
last_index (int): Current consumption index in the queue.
"""
def __init__(self, cfg: RTCConfig):
"""Initialize the action queue.
Args:
cfg: RTC configuration controlling queue behavior.
"""
self.queue = None # Processed actions for robot rollout
self.original_queue = None # Original actions for RTC
self.lock = Lock()
self.last_index = 0
self.cfg = cfg
def get(self) -> Tensor | None:
"""Get the next action from the queue.
Returns:
Tensor | None: The next action (action_dim,) or None if queue is empty.
Returns a clone to prevent external modifications.
"""
with self.lock:
if self.queue is None or self.last_index >= len(self.queue):
return None
action = self.queue[self.last_index]
self.last_index += 1
return action.clone()
def qsize(self) -> int:
"""Get the number of remaining actions in the queue.
Returns:
int: Number of unconsumed actions.
"""
if self.queue is None:
return 0
length = len(self.queue)
return length - self.last_index
def empty(self) -> bool:
"""Check if the queue is empty.
Returns:
bool: True if no actions remain, False otherwise.
"""
if self.queue is None:
return True
length = len(self.queue)
return length - self.last_index <= 0
def get_action_index(self) -> int:
"""Get the current action consumption index.
Returns:
int: Index of the next action to be consumed.
"""
return self.last_index
def get_left_over(self) -> Tensor | None:
"""Get leftover original actions for RTC prev_chunk_left_over.
These are the unconsumed actions from the current chunk, which will be
used by RTC to compute corrections for the next chunk.
Returns:
Tensor | None: Remaining original actions (remaining_steps, action_dim),
or None if no original queue exists.
"""
with self.lock:
if self.original_queue is None:
return None
return self.original_queue[self.last_index :]
def merge(
self,
original_actions: Tensor,
processed_actions: Tensor,
real_delay: int,
action_index_before_inference: int | None = 0,
):
"""Merge new actions into the queue.
This method operates differently based on RTC mode:
- RTC enabled: Replaces the queue, accounting for inference delay
- RTC disabled: Appends to the queue, maintaining continuity
Args:
original_actions: Unprocessed actions from policy (time_steps, action_dim).
processed_actions: Post-processed actions for robot (time_steps, action_dim).
real_delay: Number of time steps of inference delay.
action_index_before_inference: Index before inference started, for validation.
"""
with self.lock:
self._check_delays(real_delay, action_index_before_inference)
if self.cfg.enabled:
self._replace_actions_queue(original_actions, processed_actions, real_delay)
return
self._append_actions_queue(original_actions, processed_actions)
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
"""Replace the queue with new actions (RTC mode).
Discards the first `real_delay` actions since they correspond to the time
spent during inference, when the robot was executing previous actions.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
real_delay: Number of time steps to skip due to inference delay.
"""
self.original_queue = original_actions[real_delay:].clone()
self.queue = processed_actions[real_delay:].clone()
logger.debug(f"original_actions shape: {self.original_queue.shape}")
logger.debug(f"processed_actions shape: {self.queue.shape}")
logger.debug(f"real_delay: {real_delay}")
self.last_index = 0
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
"""Append new actions to the queue (non-RTC mode).
Removes already-consumed actions and appends new ones, maintaining
queue continuity without replacement.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
"""
if self.queue is None:
self.original_queue = original_actions.clone()
self.queue = processed_actions.clone()
return
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
self.original_queue = self.original_queue[self.last_index :]
self.queue = torch.cat([self.queue, processed_actions.clone()])
self.queue = self.queue[self.last_index :]
self.last_index = 0
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
"""Validate that computed delays match expectations.
Compares the delay computed from inference latency with the actual
number of actions consumed during inference.
Args:
real_delay: Delay computed from inference latency.
action_index_before_inference: Action index when inference started.
"""
if action_index_before_inference is None:
return
indexes_diff = self.last_index - action_index_before_inference
if indexes_diff != real_delay:
# Let's check that action index difference (real delay calculated based on action queue)
# is the same as delay calculated based on inference latency
logger.warning(
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
)
@@ -0,0 +1,55 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Real Time Chunking (RTC) and Bidirectional Decoding (BID) configuration classes.
Based on:
- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
"""
from dataclasses import dataclass
from lerobot.configs.types import RTCAttentionSchedule
@dataclass
class RTCConfig:
"""Configuration for Real Time Chunking (RTC) inference.
RTC improves real-time inference by treating chunk generation as an inpainting problem,
strategically handling overlapping timesteps between action chunks using prefix attention.
"""
# Infrastructure
enabled: bool = False
# Core RTC settings
# Todo change to exp
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
max_guidance_weight: float = 10.0
execution_horizon: int = 10
# Debug settings
debug: bool = False
debug_maxlen: int = 100
def __post_init__(self):
"""Validate RTC configuration parameters."""
if self.max_guidance_weight <= 0:
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
if self.debug_maxlen <= 0:
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
+233
View File
@@ -0,0 +1,233 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Debug information handler for Real-Time Chunking (RTC)."""
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import Tensor
@dataclass
class DebugStep:
"""Container for debug information from a single denoising step.
Attributes:
step_idx (int): Step index/counter.
x_t (Tensor | None): Current latent/state tensor.
v_t (Tensor | None): Velocity from denoiser.
x1_t (Tensor | None): Denoised prediction (x_t - time * v_t).
correction (Tensor | None): Correction gradient tensor.
err (Tensor | None): Weighted error term.
weights (Tensor | None): Prefix attention weights.
guidance_weight (float | Tensor | None): Applied guidance weight.
time (float | Tensor | None): Time parameter.
inference_delay (int | None): Inference delay parameter.
execution_horizon (int | None): Execution horizon parameter.
metadata (dict[str, Any]): Additional metadata.
"""
step_idx: int = 0
x_t: Tensor | None = None
v_t: Tensor | None = None
x1_t: Tensor | None = None
correction: Tensor | None = None
err: Tensor | None = None
weights: Tensor | None = None
guidance_weight: float | Tensor | None = None
time: float | Tensor | None = None
inference_delay: int | None = None
execution_horizon: int | None = None
metadata: dict[str, Any] = field(default_factory=dict)
def to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
"""Convert debug step to dictionary.
Args:
include_tensors (bool): If True, include tensor values. If False, only include
tensor statistics (shape, mean, std, min, max).
Returns:
Dictionary representation of the debug step.
"""
result = {
"step_idx": self.step_idx,
"guidance_weight": (
self.guidance_weight.item()
if isinstance(self.guidance_weight, Tensor)
else self.guidance_weight
),
"time": self.time.item() if isinstance(self.time, Tensor) else self.time,
"inference_delay": self.inference_delay,
"execution_horizon": self.execution_horizon,
"metadata": self.metadata.copy(),
}
# Add tensor information
tensor_fields = ["x_t", "v_t", "x1_t", "correction", "err", "weights"]
for field_name in tensor_fields:
tensor = getattr(self, field_name)
if tensor is not None:
if include_tensors:
result[field_name] = tensor.detach().cpu()
else:
result[f"{field_name}_stats"] = {
"shape": tuple(tensor.shape),
"mean": tensor.mean().item(),
"std": tensor.std().item(),
"min": tensor.min().item(),
"max": tensor.max().item(),
}
return result
class Tracker:
"""Collects and manages debug information for RTC processing.
This tracker stores debug information from recent denoising steps in a dictionary,
using time as the key for efficient lookups and updates.
Args:
enabled (bool): Whether debug collection is enabled.
maxlen (int | None): Optional sliding window size. If provided, only the
most recent ``maxlen`` debug steps are kept. If ``None``, keeps all.
"""
def __init__(self, enabled: bool = False, maxlen: int = 100):
self.enabled = enabled
self._steps = {} if enabled else None # Dictionary with time as key
self._maxlen = maxlen
self._step_counter = 0
def reset(self) -> None:
"""Clear all recorded debug information."""
if self.enabled and self._steps is not None:
self._steps.clear()
self._step_counter = 0
@torch._dynamo.disable
def track(
self,
time: float | Tensor,
x_t: Tensor | None = None,
v_t: Tensor | None = None,
x1_t: Tensor | None = None,
correction: Tensor | None = None,
err: Tensor | None = None,
weights: Tensor | None = None,
guidance_weight: float | Tensor | None = None,
inference_delay: int | None = None,
execution_horizon: int | None = None,
**metadata,
) -> None:
"""Track debug information for a denoising step at a given time.
If a step with the given time already exists, it will be updated with the new data.
Otherwise, a new step will be created. Only non-None fields are updated/set.
Note: This method is excluded from torch.compile to avoid graph breaks from
operations like .item() which are incompatible with compiled graphs.
Args:
time (float | Tensor): Time parameter - used as the key to identify the step.
x_t (Tensor | None): Current latent/state tensor.
v_t (Tensor | None): Velocity from denoiser.
x1_t (Tensor | None): Denoised prediction.
correction (Tensor | None): Correction gradient tensor.
err (Tensor | None): Weighted error term.
weights (Tensor | None): Prefix attention weights.
guidance_weight (float | Tensor | None): Applied guidance weight.
inference_delay (int | None): Inference delay parameter.
execution_horizon (int | None): Execution horizon parameter.
**metadata: Additional metadata to store.
"""
if not self.enabled:
return
# Convert time to float and round to avoid float precision issues
time_value = time.item() if isinstance(time, Tensor) else time
time_key = round(time_value, 6) # Use rounded time as dictionary key
# Check if step with this time already exists
if time_key in self._steps:
# Update existing step with non-None fields
existing_step = self._steps[time_key]
if x_t is not None:
existing_step.x_t = x_t.detach().clone()
if v_t is not None:
existing_step.v_t = v_t.detach().clone()
if x1_t is not None:
existing_step.x1_t = x1_t.detach().clone()
if correction is not None:
existing_step.correction = correction.detach().clone()
if err is not None:
existing_step.err = err.detach().clone()
if weights is not None:
existing_step.weights = weights.detach().clone()
if guidance_weight is not None:
existing_step.guidance_weight = guidance_weight
if inference_delay is not None:
existing_step.inference_delay = inference_delay
if execution_horizon is not None:
existing_step.execution_horizon = execution_horizon
if metadata:
existing_step.metadata.update(metadata)
else:
# Create new step
step = DebugStep(
step_idx=self._step_counter,
x_t=x_t.detach().clone() if x_t is not None else None,
v_t=v_t.detach().clone() if v_t is not None else None,
x1_t=x1_t.detach().clone() if x1_t is not None else None,
correction=correction.detach().clone() if correction is not None else None,
err=err.detach().clone() if err is not None else None,
weights=weights.detach().clone() if weights is not None else None,
guidance_weight=guidance_weight,
time=time_value,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
metadata=metadata,
)
# Add to dictionary
self._steps[time_key] = step
self._step_counter += 1
# Enforce maxlen if set
if self._maxlen is not None and len(self._steps) > self._maxlen:
# Remove oldest entry (first key in dict - Python 3.7+ preserves insertion order)
oldest_key = next(iter(self._steps))
del self._steps[oldest_key]
def get_all_steps(self) -> list[DebugStep]:
"""Get all recorded debug steps.
Returns:
List of all DebugStep objects (may be empty if disabled).
"""
if not self.enabled or self._steps is None:
return []
return list(self._steps.values())
def __len__(self) -> int:
"""Return the number of recorded debug steps."""
if not self.enabled or self._steps is None:
return 0
return len(self._steps)
@@ -0,0 +1,113 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Visualization utilities for RTC debug information."""
import torch
class RTCDebugVisualizer:
"""Visualizer for RTC debug information.
This class provides methods to visualize debug information collected by the Tracker,
including corrections, errors, weights, and guidance weights over denoising steps.
"""
@staticmethod
def plot_waypoints(
axes,
tensor,
start_from: int = 0,
color: str = "blue",
label: str = "",
alpha: float = 0.7,
linewidth: float = 2,
marker: str | None = None,
markersize: int = 4,
):
"""Plot trajectories across multiple dimensions.
This function plots a tensor's values across time for multiple dimensions,
with each dimension plotted on a separate axis.
Args:
axes: Array of matplotlib axes (one for each dimension).
tensor: The tensor to plot (can be torch.Tensor or numpy array).
Shape should be (time_steps, num_dims) or (batch, time_steps, num_dims).
start_from: Starting index for the x-axis.
color: Color for the plot lines.
label: Label for the plot legend.
alpha: Transparency level for the plot.
linewidth: Width of the plot lines.
marker: Marker style for data points (e.g., 'o', 's', '^').
markersize: Size of the markers.
"""
import numpy as np
# Handle None tensor
if tensor is None:
return
# Convert tensor to numpy if needed
tensor_np = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
# Handle different tensor shapes
if tensor_np.ndim == 3:
# If batch dimension present, take first batch
tensor_np = tensor_np[0]
elif tensor_np.ndim == 1:
# If 1D, reshape to (time_steps, 1)
tensor_np = tensor_np.reshape(-1, 1)
# Get dimensions
time_steps, num_dims = tensor_np.shape
# Create x-axis indices
x_indices = np.arange(start_from, start_from + time_steps)
# Plot each dimension on its corresponding axis
num_axes = len(axes) if hasattr(axes, "__len__") else 1
for dim_idx in range(min(num_dims, num_axes)):
ax = axes[dim_idx] if hasattr(axes, "__len__") else axes
# Plot the trajectory
if marker:
ax.plot(
x_indices,
tensor_np[:, dim_idx],
color=color,
label=label if dim_idx == 0 else "", # Only show label once
alpha=alpha,
linewidth=linewidth,
marker=marker,
markersize=markersize,
)
else:
ax.plot(
x_indices,
tensor_np[:, dim_idx],
color=color,
label=label if dim_idx == 0 else "", # Only show label once
alpha=alpha,
linewidth=linewidth,
)
# Add grid and labels if not already present
if not ax.xaxis.get_label().get_text():
ax.set_xlabel("Step", fontsize=10)
if not ax.yaxis.get_label().get_text():
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3)
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Latency tracking utilities for Real-Time Chunking (RTC)."""
from collections import deque
import numpy as np
class LatencyTracker:
"""Tracks recent latencies and provides max/percentile queries.
Args:
maxlen (int | None): Optional sliding window size. If provided, only the
most recent ``maxlen`` latencies are kept. If ``None``, keeps all.
"""
def __init__(self, maxlen: int = 100):
self._values = deque(maxlen=maxlen)
self.reset()
def reset(self) -> None:
"""Clear all recorded latencies."""
self._values.clear()
self.max_latency = 0.0
def add(self, latency: float) -> None:
"""Add a latency sample (seconds)."""
# Ensure numeric and non-negative
val = float(latency)
if val < 0:
return
self._values.append(val)
self.max_latency = max(self.max_latency, val)
def __len__(self) -> int:
return len(self._values)
def max(self) -> float | None:
"""Return the maximum latency or None if empty."""
return self.max_latency
def percentile(self, q: float) -> float | None:
"""Return the q-quantile (q in [0,1]) of recorded latencies or None if empty."""
if not self._values:
return 0.0
q = float(q)
if q <= 0.0:
return min(self._values)
if q >= 1.0:
return self.max_latency
vals = np.array(list(self._values), dtype=np.float32)
return float(np.quantile(vals, q))
def p95(self) -> float | None:
"""Return the 95th percentile latency or None if empty."""
return self.percentile(0.95)
+297
View File
@@ -0,0 +1,297 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Real-Time Chunking (RTC) implementation for LeRobot.
Based on Physical Intelligence's Kinetix implementation:
https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214
"""
import logging
import math
import torch
from torch import Tensor
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_tracker import Tracker
logger = logging.getLogger(__name__)
class RTCProcessor:
"""Real-Time Chunking processor for action chunking policies.
This class implements RTC techniques including velocity calculation,
prefix attention, and adaptive chunk processing.
"""
def __init__(self, rtc_config: RTCConfig):
self.rtc_config = rtc_config
self.tracker = None
if rtc_config.debug:
self.tracker = Tracker(
enabled=rtc_config.debug,
maxlen=rtc_config.debug_maxlen,
)
# ====================== Tracker Proxy Methods ======================
def track(
self,
time: float | Tensor,
x_t: Tensor | None = None,
v_t: Tensor | None = None,
x1_t: Tensor | None = None,
correction: Tensor | None = None,
err: Tensor | None = None,
weights: Tensor | None = None,
guidance_weight: float | Tensor | None = None,
inference_delay: int | None = None,
execution_horizon: int | None = None,
**metadata,
) -> None:
"""Proxy method to track debug information.
If tracker is None or disabled, this method does nothing.
Otherwise, it forwards the call to tracker.track().
"""
if self.tracker is not None:
self.tracker.track(
time=time,
x_t=x_t,
v_t=v_t,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
**metadata,
)
def get_all_debug_steps(self) -> list:
"""Get all debug steps from tracker.
Returns empty list if tracker is disabled or None.
"""
if self.tracker is not None:
return self.tracker.get_all_steps()
return []
def is_debug_enabled(self) -> bool:
"""Check if debug tracking is enabled.
Returns True if tracker exists and is enabled.
"""
return self.tracker is not None and self.tracker.enabled
def reset_tracker(self) -> None:
"""Reset the tracker, clearing all recorded steps.
Does nothing if tracker is None.
"""
if self.tracker is not None:
self.tracker.reset()
# ====================== End Tracker Proxy Methods ======================
def denoise_step(
self,
x_t,
prev_chunk_left_over,
inference_delay,
time,
original_denoise_step_partial,
execution_horizon=None,
) -> Tensor:
"""RTC guidance wrapper around an existing denoiser.
This method wraps an original denoising callable that only takes ``x_t`` and
returns a base denoised velocity ``v_t``. It then applies Real-Time Chunking
(RTC) prefix guidance using the leftover prefix from the previous chunk.
Args:
x_t (Tensor): Current latent/state to denoise. Shape ``(B, T, A)`` or ``(T, A)``.
prev_chunk_left_over (Tensor | None): Unexecuted prefix from the previous
chunk. Shape ``(B, T_prev, A)`` or ``(T_prev, A)``. If ``None``, no guidance
is applied and the method returns ``v_t`` from the original denoiser.
inference_delay (int): Number of timesteps from the prefix to use for guidance.
time (float | Tensor): Scalar in [0, 1] indicating normalized time. Must be
broadcastable with ``x_t``.
original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that
computes the base denoised velocity given only ``x_t``.
execution_horizon (int | None): Horizon used to build prefix weights. If
``None``, defaults to ``self.rtc_config.execution_horizon``.
Returns:
Tensor: Guided velocity with the same shape as ``v_t``.
Notes:
- If inputs are 2D, a batch dimension is temporarily added and removed at the end.
- If ``prev_chunk_left_over`` is shorter than the current chunk length ``T``, it is
right-padded with zeros to match ``T``.
- Prefix weights are constructed via ``get_prefix_weights(inference_delay, execution_horizon, T)``
and broadcast to ``(B, T, A)``.
- Guidance correction is computed via autograd using ``x1_t = x_t + time * v_t`` and
``error = (prev_chunk_left_over - x1_t) * weights``.
- The final guidance weight is clamped by ``max_guidance_weight`` from the config.
Reference:
https://www.physicalintelligence.company/download/real_time_chunking.pdf
"""
# In the original implementation, the time goes from 0 to 1 and
# In our implementation, the time goes from 1 to 0
# So we need to invert the time
tau = 1 - time
if prev_chunk_left_over is None:
# First step, no guidance - return v_t
v_t = original_denoise_step_partial(x_t)
return v_t
x_t = x_t.clone().detach()
squeezed = False
if len(x_t.shape) < 3:
# Add batch dimension
x_t = x_t.unsqueeze(0)
squeezed = True
if len(prev_chunk_left_over.shape) < 3:
# Add batch dimension
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
if execution_horizon is None:
execution_horizon = self.rtc_config.execution_horizon
# If the previous action chunk is to short then it doesn't make sense to use long execution horizon
# because there is nothing to merge
if execution_horizon > prev_chunk_left_over.shape[1]:
execution_horizon = prev_chunk_left_over.shape[1]
batch_size = x_t.shape[0]
action_chunk_size = x_t.shape[1]
action_dim = x_t.shape[2]
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
prev_chunk_left_over = padded
assert prev_chunk_left_over.shape == x_t.shape, (
"The padded previous chunk must be the same size as the input tensor"
)
weights = (
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
.to(x_t.device)
.unsqueeze(0)
.unsqueeze(-1)
)
with torch.enable_grad():
v_t = original_denoise_step_partial(x_t)
x_t.requires_grad_(True)
x1_t = x_t - time * v_t # noqa: N806
err = (prev_chunk_left_over - x1_t) * weights
grad_outputs = err.clone().detach()
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
result = v_t - guidance_weight * correction
# Remove the batch dimension if it was added
if squeezed:
result = result.squeeze(0)
correction = correction.squeeze(0)
x1_t = x1_t.squeeze(0)
err = err.squeeze(0)
self.track(
time=time,
x1_t=x1_t,
correction=correction,
err=err,
weights=weights,
guidance_weight=guidance_weight,
inference_delay=inference_delay,
execution_horizon=execution_horizon,
)
return result
def get_prefix_weights(self, start, end, total):
start = min(start, end)
if self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS:
weights = torch.zeros(total)
weights[:start] = 1.0
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ONES:
weights = torch.ones(total)
weights[end:] = 0.0
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR:
lin_weights = self._linweights(start, end, total)
weights = self._add_trailing_zeros(lin_weights, total, end)
weights = self._add_leading_ones(weights, start, total)
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.EXP:
lin_weights = self._linweights(start, end, total)
lin_weights = lin_weights * torch.expm1(lin_weights).div(math.e - 1)
weights = self._add_trailing_zeros(lin_weights, total, end)
weights = self._add_leading_ones(weights, start, total)
return weights
def _linweights(self, start, end, total):
skip_steps_at_end = max(total - end, 0)
linspace_steps = total - skip_steps_at_end - start
if end <= start or linspace_steps <= 0:
return torch.tensor([])
return torch.linspace(1, 0, linspace_steps + 2)[1:-1]
def _add_trailing_zeros(self, weights, total, end):
zeros_len = total - end
if zeros_len <= 0:
return weights
zeros = torch.zeros(zeros_len)
return torch.cat([weights, zeros])
def _add_leading_ones(self, weights, start, total):
ones_len = min(start, total)
if ones_len <= 0:
return weights
ones = torch.ones(ones_len)
return torch.cat([ones, weights])
@@ -14,8 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarms_mini import OpenArmsMiniConfig
from .openarms_mini import OpenArmsMini
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.modeling_sarm import (
SARMRewardModel,
SARMTransformer,
)
from lerobot.policies.sarm.processor_sarm import (
SARMEncodingProcessorStep,
make_sarm_pre_post_processors,
)
__all__ = ["OpenArmsMini", "OpenArmsMiniConfig"]
__all__ = [
"SARMConfig",
"SARMRewardModel",
"SARMTransformer",
"SARMEncodingProcessorStep",
"make_sarm_pre_post_processors",
]
@@ -0,0 +1,186 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("sarm")
@dataclass
class SARMConfig(PreTrainedConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
# CLIP params
image_dim: int = 512
text_dim: int = 512
num_frames: int = 9 # 1 initial + 8 consecutive frames
frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
# Architecture params
hidden_dim: int = 768
num_heads: int = 12
num_layers: int = 8
max_state_dim: int = 32
num_stages: int = 5 # Number of task stages (auto-updated from annotations if available)
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
use_temporal_sampler: bool = True # Always enable temporal sequence loading
# Training params
batch_size: int = 64
clip_batch_size: int = 64 # Batch size for CLIP encoding
dropout: float = 0.1
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
pretrained_model_path: str | None = None
device: str | None = None
# Processor settings
image_key: str = "observation.images.top" # Key for image used from the dataset
# State key in the dataset (for normalization)
state_key: str = "observation.state"
# Populated by the processor (video_features, state_features, text_features)
input_features: dict = field(default_factory=lambda: {})
# Output features
output_features: dict = field(default_factory=lambda: {
"stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
"progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
})
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"LANGUAGE": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self):
super().__post_init__()
# Add the image_key as VISUAL
if self.image_key:
self.input_features[self.image_key] = PolicyFeature(
shape=(480, 640, 3),
type=FeatureType.VISUAL
)
# Add state_key as STATE
self.input_features[self.state_key] = PolicyFeature(
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
type=FeatureType.STATE
)
# Update output features with actual dimensions
self.output_features["stage"] = PolicyFeature(
shape=(self.num_frames, self.num_stages),
type=FeatureType.REWARD
)
self.output_features["progress"] = PolicyFeature(
shape=(self.num_frames, 1),
type=FeatureType.REWARD
)
# Validate configuration
if self.hidden_dim % self.num_heads != 0:
raise ValueError(
f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
)
if self.max_length != self.num_frames:
raise ValueError(
f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})"
)
if self.num_stages < 2:
raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")
def get_optimizer_preset(self) -> AdamWConfig:
"""Get default optimizer configuration for SARM training."""
return AdamWConfig(
lr=5e-5,
weight_decay=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Get default learning rate scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=5e-5,
decay_lr=5e-6,
num_warmup_steps=500,
num_decay_steps=50000,
)
def validate_features(self) -> None:
"""Validate input and output features."""
pass
@property
def observation_delta_indices(self) -> list[int]:
"""Load frames for SARM temporal sampling with SYMMETRIC/BIDIRECTIONAL pattern.
The model uses 9 frames with symmetric context around current frame:
- Frame 0: Initial frame of the episode (clamped via large negative delta)
- Frames 1-8: Symmetric context: 4 before + current + 3 after
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling (done by dataset loader):
- Early frames: backward indices clamp to 0 (first frame)
- Late frames: forward indices clamp to episode end (last frame)
This enables truly uniform sampling across entire episodes.
Returns:
9 delta indices: [-1_000_000, -4*gap, -3*gap, -2*gap, -gap, 0, gap, 2*gap, 3*gap]
"""
initial_frame_delta = -1_000_000
# Symmetric pattern: 4 frames before, current (0), 3 frames after = 8 context frames
symmetric_deltas = [
-4 * self.frame_gap,
-3 * self.frame_gap,
-2 * self.frame_gap,
-1 * self.frame_gap,
0, # current frame
1 * self.frame_gap,
2 * self.frame_gap,
3 * self.frame_gap,
]
return [initial_frame_delta] + symmetric_deltas
@property
def action_delta_indices(self) -> None:
"""SARM is a reward model, not an action policy."""
return None
@property
def reward_delta_indices(self) -> None:
"""SARM doesn't use delta rewards."""
return None
+650
View File
@@ -0,0 +1,650 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Union, Optional
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from torch import Tensor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.pretrained import PreTrainedPolicy
class SARMTransformer(nn.Module):
"""
SARM Transformer model for stage-aware reward prediction.
This model has a dual-head architecture:
1. Stage estimator: Predicts the high-level task stage (classification)
2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
"""
def __init__(
self,
video_dim: int = 512,
text_dim: int = 512,
max_state_dim: int = 32,
hidden_dim: int = 768,
num_heads: int = 12,
num_layers: int = 8,
num_stages: int = 5,
max_length: int = 9,
dropout: float = 0.1,
temporal_proportions: list[float] | None = None
):
super().__init__()
self.hidden_dim = hidden_dim
self.max_length = max_length
self.num_stages = num_stages
self.max_state_dim = max_state_dim
if temporal_proportions is None:
raise ValueError(
"temporal_proportions is required for SARM. "
"Provide subtask annotations in your dataset or set temporal_proportions in config."
)
# ᾱ_k: proportion for each stage
alpha = torch.tensor(temporal_proportions, dtype=torch.float32)
# P_k: cumulative proportion up to stage k (P_0 = 0)
cumulative = torch.zeros(num_stages + 1, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
self.register_buffer('alpha', alpha)
self.register_buffer('cumulative_prior', cumulative)
self.video_proj = nn.Linear(video_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
self.state_proj = nn.Linear(max_state_dim, hidden_dim)
# Position embedding only for the first frame
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Stage estimator head (classification)
self.stage_head = nn.Sequential(
nn.Linear(hidden_dim, 512),
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(512, num_stages)
)
# Subtask estimator head (regression)
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
subtask_input_dim = hidden_dim + hidden_dim // 4
self.subtask_head = nn.Sequential(
nn.Linear(subtask_input_dim, 512),
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(512, 1),
nn.Sigmoid()
)
# Attention mask
self.register_buffer("attention_mask", None, persistent=False)
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
"""Generate or retrieve cached causal attention mask."""
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
# Create causal mask
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
self.attention_mask = mask
return self.attention_mask
def forward(
self,
video_frames: torch.Tensor,
text_embed: torch.Tensor,
state_features: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass through the SARM transformer.
Args:
video_frames: Video frame embeddings (batch_size, seq_len, video_dim)
text_embed: Text embeddings (batch_size, text_dim)
state_features: Joint state features (batch_size, seq_len, state_dim)
Returns:
Tuple of:
- Stage logits for each frame (batch_size, seq_len, num_stages)
- Stage probabilities (batch_size, seq_len, num_stages)
- Progress predictions for each frame (batch_size, seq_len, 1)
"""
# Project inputs to common dimension
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
# Pad state features to max_state_dim before projection
state_features_padded = pad_state_to_max_dim(state_features, self.max_state_dim)
state_embed = self.state_proj(state_features_padded) # [batch_size, seq_len, hidden_dim]
# Fuse video and state features
video_embed = video_embed + state_embed
# Add positional embedding to first video frame
video_embed[:, 0] += self.first_pos_embed
# Combine sequence: [text, video_frames]
sequence = torch.cat([text_embed, video_embed], dim=1)
# Get causal attention mask
seq_length = sequence.shape[1]
attention_mask = self._get_attention_mask(seq_length, sequence.device)
# Pass through transformer with causal masking
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
# Get frame features
frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim]
# Stage estimation
stage_logits = self.stage_head(frame_features) # [batch_size, seq_len, num_stages]
stage_probs = F.softmax(stage_logits, dim=-1) # [batch_size, seq_len, num_stages]
# Get predicted stage indices
stage_indices = torch.argmax(stage_probs, dim=-1) # [batch_size, seq_len]
# Get stage embeddings for conditioning
stage_embeds = self.stage_embedding(stage_indices)
# Concatenate frame features with stage embeddings
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
# Subtask progress estimation (conditioned on stage)
# τ̂ = within-subtask progress (0-1)
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
# Convert τ̂ to cumulative progress ŷ using Paper Formula (2):
# ŷ = P_{k-1} + ᾱ_k × τ̂
progress_preds = compute_cumulative_progress_batch(
tau_preds, stage_indices, self.alpha, self.cumulative_prior
)
return stage_logits, stage_probs, progress_preds
class SARMRewardModel(PreTrainedPolicy):
"""
SARM Reward Model for stage-aware task completion rewards.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
This model combines:
- CLIP for encoding video frames AND text descriptions
- SARMTransformer for predicting task stage and progress
- Optional RA-BC (Reward-Aligned Behavior Cloning) for weighted training
"""
name = "sarm"
config_class = SARMConfig
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
super().__init__(config, dataset_stats)
config.validate_features()
self.config = config
self.dataset_stats = dataset_stats
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
# Load temporal proportions from dataset
if config.temporal_proportions is None and dataset_meta is not None:
self._load_temporal_proportions(dataset_meta)
logging.info("Loading CLIP encoder")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
self.sarm_transformer = SARMTransformer(
video_dim=config.image_dim,
text_dim=config.text_dim,
max_state_dim=config.max_state_dim,
hidden_dim=config.hidden_dim,
num_heads=config.num_heads,
num_layers=config.num_layers,
num_stages=config.num_stages,
max_length=config.max_length,
dropout=config.dropout,
temporal_proportions=config.temporal_proportions
)
self.sarm_transformer.to(self.device)
logging.info(f"SARM initialized on {self.device}")
def _load_temporal_proportions(self, dataset_meta) -> None:
"""
Load pre-computed temporal proportions from dataset metadata JSON file.
The temporal proportions are computed during dataset annotation using SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
"""
import json
proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json"
if not proportions_path.exists():
raise ValueError(
f"Temporal proportions not found at {proportions_path}. "
"Run the subtask annotation tool first to compute and save temporal proportions."
)
with open(proportions_path, "r") as f:
temporal_proportions_dict = json.load(f)
# Sort subtask names for consistent ordering
subtask_names = sorted(temporal_proportions_dict.keys())
self.config.num_stages = len(subtask_names)
self.config.subtask_names = subtask_names
self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}")
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
def to(self, device):
"""Override to method to ensure all components move together."""
super().to(device)
self.device = device if isinstance(device, torch.device) else torch.device(device)
self.clip_model.to(device)
self.sarm_transformer.to(device)
return self
@torch.no_grad()
def encode_images(self, images: np.ndarray) -> np.ndarray:
"""
Encode video frames using CLIP.
Args:
images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8.
Can also be (num_frames, H, W, C) for a single video.
Returns:
Encoded image features (num_videos, num_frames, 512) or (num_frames, 512).
"""
# Handle single video case
single_video = False
if len(images.shape) == 4:
images = images[np.newaxis, ...]
single_video = True
assert len(images.shape) == 5, f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}"
all_embeddings = []
for video in images:
video_embeddings = []
# Convert frames to PIL images for CLIP processor
frames = []
for frame in video:
if frame.shape[0] == 3: # Channel first
frame = frame.transpose(1, 2, 0)
if frame.dtype != np.uint8:
frame = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
frames.append(Image.fromarray(frame))
# Batch process frames with CLIP
for i in range(0, len(frames), self.config.clip_batch_size):
batch = frames[i:i + self.config.clip_batch_size]
inputs = self.clip_processor(images=batch, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings from CLIP
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
# Handle single frame case
if embeddings.dim() == 1:
embeddings = embeddings.unsqueeze(0)
video_embeddings.append(embeddings)
video_embeddings = torch.cat(video_embeddings)
all_embeddings.append(video_embeddings)
result = torch.stack(all_embeddings).numpy()
if single_video:
result = result[0]
return result
@torch.no_grad()
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
"""
Encode text using CLIP text encoder (per SARM paper A.4).
Args:
text: Text string or list of text strings.
Returns:
Encoded text features (batch_size, 512) or (512,) for single text.
"""
if isinstance(text, str):
text = [text]
single_text = True
else:
single_text = False
# Use CLIP's tokenizer directly (avoids image processor validation issues)
tokenizer = self.clip_processor.tokenizer
# Process in batches
all_embeddings = []
for i in range(0, len(text), self.config.batch_size):
batch_text = text[i:i + self.config.batch_size]
inputs = tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
text_embeddings = self.clip_model.get_text_features(**inputs)
all_embeddings.append(text_embeddings.cpu())
result = torch.cat(all_embeddings).numpy()
if single_text:
result = result[0]
return result
@torch.no_grad()
def calculate_rewards(
self,
text_embeddings: Union[np.ndarray, torch.Tensor],
video_embeddings: Union[np.ndarray, torch.Tensor],
state_features: Optional[Union[np.ndarray, torch.Tensor]] = None,
return_all_frames: bool = False,
return_stages: bool = False
) -> Union[np.ndarray, tuple]:
"""
Calculate rewards for given text, video, and state representations.
Args:
text_embeddings: Encoded text representations (batch_size, 512)
video_embeddings: Encoded video representations (batch_size, num_frames, 512)
state_features: Joint state features (batch_size, num_frames, state_dim)
return_all_frames: If True, return rewards for all frames
return_stages: If True, also return stage predictions
Returns:
If return_stages=False:
Reward values (batch_size,) or (batch_size, num_frames)
If return_stages=True:
Tuple of (rewards, stage_probs)
"""
if isinstance(text_embeddings, np.ndarray):
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
if isinstance(video_embeddings, np.ndarray):
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
if state_features is not None and isinstance(state_features, np.ndarray):
state_features = torch.tensor(state_features, dtype=torch.float32)
# Handle single sample case
if text_embeddings.dim() == 1:
text_embeddings = text_embeddings.unsqueeze(0)
video_embeddings = video_embeddings.unsqueeze(0)
if state_features is not None:
state_features = state_features.unsqueeze(0)
single_sample = True
else:
single_sample = False
# Process in batches
all_rewards = []
all_stage_probs = []
for i in range(0, len(video_embeddings), self.config.batch_size):
batch_texts = text_embeddings[i:i + self.config.batch_size].to(self.device)
batch_videos = video_embeddings[i:i + self.config.batch_size].to(self.device)
batch_states = None
if state_features is not None:
batch_states = state_features[i:i + self.config.batch_size].to(self.device)
# Get predictions
stage_logits, stage_probs, progress_preds = self.sarm_transformer(
batch_videos.float(), batch_texts.float(), batch_states.float() if batch_states is not None else None
)
if return_all_frames:
all_rewards.append(progress_preds.squeeze(-1).cpu())
else:
# Return only last frame reward
all_rewards.append(progress_preds[:, -1, 0].cpu())
if return_stages:
all_stage_probs.append(stage_probs.cpu())
rewards = torch.cat(all_rewards).numpy()
if single_sample:
rewards = rewards[0] if not return_all_frames else rewards[0]
if return_stages:
stage_probs = torch.cat(all_stage_probs).numpy()
if single_sample:
stage_probs = stage_probs[0]
return rewards, stage_probs
return rewards
def train(self, mode: bool = True):
"""Overwrite train method to ensure CLIP encoder stays frozen during training"""
super().train(mode)
self.clip_model.eval()
self.sarm_transformer.train(mode)
return self
def eval(self):
"""Overwrite eval method to ensure CLIP encoder stays frozen during evaluation"""
return self.train(False)
def parameters(self):
"""Override to return trainable parameters (only SARM transformer, not CLIP encoder)."""
return self.sarm_transformer.parameters()
def get_optim_params(self):
"""Override to return optimizer parameters (only SARM transformer, not CLIP encoder)."""
return self.parameters()
def reset(self):
"""Required by PreTrainedPolicy but not used for reward models."""
pass
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for reward models."""
raise NotImplementedError("SARM model does not predict action chunks")
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Required by PreTrainedPolicy but not used for SARM."""
raise NotImplementedError("SARM model does not select actions")
def _apply_temporal_augmentation(
self,
video: torch.Tensor,
progress: torch.Tensor,
state: torch.Tensor | None,
max_length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Apply temporal augmentation by appending reversed frames (SARM paper A.4).
This helps the model learn to handle non-monotonic progress (failures, recoveries).
Appends 1-4 reversed frames to simulate going backwards in task progress.
"""
num_reverse = random.randint(1, min(4, max_length - 1))
# Reverse and take frames (skip first which is last of original)
reversed_video = video.flip(0)[1:num_reverse + 1]
reversed_progress = progress.flip(0)[1:num_reverse + 1]
# Concatenate and trim
video = torch.cat([video, reversed_video], dim=0)[:max_length]
progress = torch.cat([progress, reversed_progress], dim=0)[:max_length]
if state is not None:
reversed_state = state.flip(0)[1:num_reverse + 1]
state = torch.cat([state, reversed_state], dim=0)[:max_length]
return video, progress, state
def _ensure_sequence_length(self, tensor: torch.Tensor, target_len: int) -> torch.Tensor:
"""Pad or trim tensor to target length."""
current_len = tensor.shape[0]
if current_len == target_len:
return tensor
if current_len < target_len:
padding = target_len - current_len
return torch.cat([tensor, tensor[-1:].expand(padding, *tensor.shape[1:])])
return tensor[:target_len]
def forward(self, batch):
"""
Forward pass for SARM reward model training.
Uses annotation-based progress targets following SARM paper Eq. 2:
yt = Pk-1 + α̅k × τt
where:
- τt = (t - sk) / (ek - sk) is within-subtask normalized time
- Pk-1 is cumulative prior (sum of previous subtask proportions)
- α̅k is the temporal proportion for subtask k
Args:
batch: Dictionary with 'observation' containing:
- 'video_features': (B, T, 512) pre-encoded video features
- 'text_features': (B, 512) pre-encoded text features (CLIP)
- 'state_features': (B, T, state_dim) joint state features
- 'stage_labels': (B, T) stage labels from annotations
- 'progress_targets': (B, T, 1) progress targets from annotations
Returns:
Tuple of (total_loss, output_dict with loss components)
"""
observation = batch.get('observation', batch)
# Extract required features
video_features = observation['video_features'].to(self.device)
text_features = observation['text_features'].to(self.device)
state_features = observation.get('state_features').to(self.device)
batch_size = video_features.shape[0]
max_length = self.config.num_frames
# Ensure 3D video features (B, T, D)
if video_features.dim() == 2:
video_features = video_features.unsqueeze(1).expand(-1, max_length, -1)
if state_features is not None and state_features.dim() == 2:
state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)
# Get annotation-based progress targets (required for SARM paper formula)
progress_from_annotations = observation.get('progress_targets')
if progress_from_annotations is None:
raise ValueError("progress_targets from annotations is required for SARM training")
progress_from_annotations = progress_from_annotations.to(self.device)
if progress_from_annotations.dim() == 2:
progress_from_annotations = progress_from_annotations.unsqueeze(-1)
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
# Process each sample: apply temporal REWIND augmentation
processed_videos = []
processed_states = []
progress_targets = []
for i in range(batch_size):
video = video_features[i]
state = state_features[i] if state_features is not None else None
progress = progress_from_annotations[i].squeeze(-1) # (T,)
# Apply temporal REWIND augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
if random.random() < 0.5:
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
# Ensure correct sequence length
video = self._ensure_sequence_length(video, max_length)
progress = self._ensure_sequence_length(progress.unsqueeze(-1), max_length).squeeze(-1)
if state is not None:
state = self._ensure_sequence_length(state, max_length)
processed_videos.append(video)
progress_targets.append(progress)
if state is not None:
processed_states.append(state)
# Stack into batches
processed_videos = torch.stack(processed_videos)
progress_targets = torch.stack(progress_targets).unsqueeze(-1) # (B, T, 1)
processed_states = torch.stack(processed_states) if processed_states else None
# Get model predictions
stage_logits, stage_probs, progress_preds = self.sarm_transformer(
processed_videos, text_features, processed_states
)
# Compute progress loss (MSE)
progress_loss = F.mse_loss(progress_preds, progress_targets)
output_dict = {'progress_loss': progress_loss.item()}
total_loss = progress_loss
# Compute stage loss (cross-entropy)
stage_labels = observation.get('stage_labels')
if stage_labels is None:
raise ValueError("stage_labels from annotations is required for SARM training")
stage_labels = stage_labels.to(self.device)
if stage_labels.dim() == 1:
stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
stage_loss = compute_stage_loss(stage_logits, stage_labels)
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
output_dict['stage_loss'] = stage_loss.item()
# Misaligned loss: 20% probability
if random.random() < 0.2:
shuffle_idx = torch.randperm(batch_size, device=self.device)
_, _, misaligned_preds = self.sarm_transformer(
processed_videos, text_features[shuffle_idx], processed_states
)
misaligned_loss = F.mse_loss(misaligned_preds, torch.zeros_like(misaligned_preds))
total_loss = total_loss + misaligned_loss
output_dict['misaligned_loss'] = misaligned_loss.item()
output_dict['total_loss'] = total_loss.item()
return total_loss, output_dict
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
_, _, num_stages = stage_logits.shape
stage_logits_flat = stage_logits.reshape(-1, num_stages)
target_stages_flat = target_stages.reshape(-1)
loss = F.cross_entropy(stage_logits_flat, target_stages_flat)
return loss
+644
View File
@@ -0,0 +1,644 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import numpy as np
import torch
from PIL import Image
import pandas as pd
from transformers import CLIPModel, CLIPProcessor
from lerobot.processor.core import TransitionKey
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.processor import (
ProcessorStep,
PolicyProcessorPipeline,
PolicyAction,
DeviceProcessorStep,
AddBatchDimensionProcessorStep,
NormalizerProcessorStep,
)
from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
from_tensor_to_numpy,
)
from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.configs.types import PolicyFeature, FeatureType
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class SARMEncodingProcessorStep(ProcessorStep):
"""ProcessorStep that encodes images and text with CLIP."""
def __init__(
self,
config: SARMConfig,
image_key: str | None = None,
dataset_meta = None,
dataset_stats: dict | None = None,
):
super().__init__()
self.config = config
self.image_key = image_key or config.image_key
self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats
self.temporal_proportions = {name: prop for name, prop in zip(self.config.subtask_names, self.config.temporal_proportions)}
self.subtask_names = self.config.subtask_names
self.device = torch.device(
self.config.device if self.config.device
else "cuda" if torch.cuda.is_available() else "cpu"
)
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
def _find_episode_for_frame(self, frame_idx: int) -> int:
"""Find the episode index for a given frame index."""
for ep_idx in range(len(self.dataset_meta.episodes)):
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
return ep_idx
return 0
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
"""Get episode indices for each frame index."""
if episode_index is None:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
# If single episode but multiple frames, compute episode for each frame
if len(episode_indices) == 1 and len(frame_indices) > 1:
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
return episode_indices
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, ep_end: int, num_frames: int) -> torch.Tensor:
"""Compute absolute frame indices for symmetric bidirectional pattern.
Pattern: [ep_start, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Backward indices clamp to ep_start (first frame)
- Forward indices clamp to ep_end - 1 (last frame)
"""
indices = []
indices.append(ep_start) # Initial frame is always episode start
# Symmetric context: 4 before, current, 3 after
num_before = 4
num_after = 3
last_valid_frame = ep_end - 1
# Frames before current (clamp to first frame)
for i in range(num_before, 0, -1):
idx = max(ep_start, frame_idx - i * self.config.frame_gap)
indices.append(idx)
# Current frame
indices.append(frame_idx)
# Frames after current (clamp to last frame)
for i in range(1, num_after + 1):
idx = min(last_valid_frame, frame_idx + i * self.config.frame_gap)
indices.append(idx)
return torch.tensor(indices)
def _compute_episode_metadata(
self,
frame_indices: np.ndarray,
episode_indices: np.ndarray,
num_frames: int,
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute episode metadata for all samples.
Returns:
Tuple of (absolute_frame_indices, remaining_lengths, episode_lengths)
"""
absolute_indices_list = []
remaining_lengths = []
episode_lengths = []
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
ep_idx, frame_idx = int(ep_idx), int(frame_idx)
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
episode_lengths.append(ep_end - ep_start)
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, ep_end, num_frames)
absolute_indices_list.append(abs_indices)
remaining_lengths.append(ep_end - abs_indices[0].item())
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
def _compute_stage_and_progress_for_frame(
self,
current_frame: int,
subtask_names: list,
subtask_start_frames: list,
subtask_end_frames: list,
transition_smoothing_frames: int = 15,
) -> tuple[int, float, dict[int, float] | None]:
"""Compute stage index, cumulative progress, and soft stage labels for a single frame.
Implements SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Additionally computes soft stage labels near transitions to mitigate discrete jumps
in the stage classifier. Near stage boundaries, labels are blended between adjacent
stages to encourage smoother predictions.
Args:
current_frame: Frame index relative to episode start
subtask_names: List of subtask names for this episode
subtask_start_frames: List of subtask start frames
subtask_end_frames: List of subtask end frames
transition_smoothing_frames: Number of frames over which to smooth labels near transitions
Returns:
Tuple of (stage_idx, cumulative_progress, soft_stage_labels)
- stage_idx: Hard stage index (for compatibility)
- cumulative_progress: Progress value in [0, 1]
- soft_stage_labels: Dict mapping stage_idx -> probability, or None if not near transition
"""
# Get temporal proportions as list for compute_cumulative_progress
temporal_proportions_list = [
self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
]
num_stages = len(self.subtask_names)
# Find which subtask this frame belongs to
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
if current_frame >= start_frame and current_frame <= end_frame:
# Found the subtask, get its global index
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
# Compute τ_t using utility function (Paper Formula 2)
tau = compute_tau(current_frame, start_frame, end_frame)
# Compute cumulative progress using utility function (Paper Formula 2)
cumulative_progress = compute_cumulative_progress_batch(
tau, stage_idx, temporal_proportions_list
)
# Compute soft stage labels near transitions
soft_stage_labels = None
frames_from_start = current_frame - start_frame
frames_to_end = end_frame - current_frame
if frames_from_start < transition_smoothing_frames and j > 0:
# Near start of stage - blend with previous stage
blend = frames_from_start / transition_smoothing_frames
prev_name = subtask_names[j - 1]
prev_stage_idx = self.subtask_names.index(prev_name) if prev_name in self.subtask_names else max(0, stage_idx - 1)
soft_stage_labels = {prev_stage_idx: 1.0 - blend, stage_idx: blend}
elif frames_to_end < transition_smoothing_frames and j < len(subtask_names) - 1:
# Near end of stage - blend with next stage
blend = frames_to_end / transition_smoothing_frames
next_name = subtask_names[j + 1]
next_stage_idx = self.subtask_names.index(next_name) if next_name in self.subtask_names else min(num_stages - 1, stage_idx + 1)
soft_stage_labels = {stage_idx: blend, next_stage_idx: 1.0 - blend}
return stage_idx, cumulative_progress, soft_stage_labels
# No matching subtask found
if current_frame < subtask_start_frames[0]:
return 0, 0.0, None
elif current_frame > subtask_end_frames[-1]:
return len(self.subtask_names) - 1, 1.0, None
else:
# Between subtasks - use previous subtask's end state (tau = 1.0)
for j in range(len(subtask_names) - 1):
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
name = subtask_names[j]
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
# Completed subtask, so tau = 1.0
cumulative_progress = compute_cumulative_progress_batch(
1.0, stage_idx, temporal_proportions_list
)
return stage_idx, cumulative_progress, None
return 0, 0.0, None
def _compute_labels_for_sample(
self,
frame_idx: int,
ep_idx: int,
seq_len: int,
episodes_df: pd.DataFrame,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | tuple[None, None, None]:
"""Compute stage labels, progress targets, and soft stage labels for symmetric bidirectional pattern.
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Before episode start: clamp to frame 0 (progress ~0%)
- After episode end: clamp to last frame (progress ~100%)
Soft stage labels are computed near stage transitions to mitigate discrete jumps.
Args:
frame_idx: The frame index for this sample
ep_idx: The episode index
seq_len: Number of frames in the sequence
episodes_df: DataFrame with episode metadata
Returns:
Tuple of (stage_labels, progress_targets, soft_stage_labels):
- stage_labels: (T,) hard stage indices
- progress_targets: (T, 1) progress values
- soft_stage_labels: (T, num_stages) soft probability labels, or None if no transitions nearby
"""
# Check if episode has valid annotations
if ep_idx >= len(episodes_df):
return None, None, None
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
return None, None, None
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
ep_length = ep_end - ep_start
last_valid_frame = ep_length - 1
num_stages = len(self.subtask_names)
# Generate labels for each frame in the sequence
stage_labels = []
progress_targets = []
soft_labels_list = [] # List of soft label dicts (or None)
has_any_soft_labels = False
# Symmetric pattern: initial + 4 before + current + 3 after = 9 frames
num_before = 4
num_after = 3
for i in range(seq_len):
if i == 0:
# Position 0: Initial frame of the episode
current_frame = 0 # Relative to episode start
elif i <= num_before:
# Positions 1-4: frames before current (with clamping to first frame)
offset = -(num_before - i + 1) * self.config.frame_gap
current_frame = max(0, frame_idx + offset - ep_start)
elif i == num_before + 1:
# Position 5: current frame
current_frame = frame_idx - ep_start
else:
# Positions 6-8: frames after current (with clamping to last frame)
offset = (i - num_before - 1) * self.config.frame_gap
current_frame = min(last_valid_frame, frame_idx + offset - ep_start)
stage_idx, cumulative_progress, soft_stage_labels = self._compute_stage_and_progress_for_frame(
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
)
stage_labels.append(stage_idx)
progress_targets.append(cumulative_progress)
soft_labels_list.append(soft_stage_labels)
if soft_stage_labels is not None:
has_any_soft_labels = True
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
# Convert soft labels to tensor if any exist
soft_stage_labels_tensor = None
if has_any_soft_labels:
soft_stage_labels_tensor = torch.zeros(seq_len, num_stages, dtype=torch.float32)
for i, soft_dict in enumerate(soft_labels_list):
if soft_dict is not None:
for stage_idx, prob in soft_dict.items():
soft_stage_labels_tensor[i, stage_idx] = prob
else:
# Use hard one-hot label
soft_stage_labels_tensor[i, stage_labels[i]] = 1.0
return stage_labels, progress_targets, soft_stage_labels_tensor
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
"""Generate stage labels, progress targets, and soft stage labels from subtask annotations.
Args:
frame_index: Current frame index or tensor of indices
episode_index: Episode index or tensor of indices
video_features: Video features tensor to determine sequence length
Returns:
Tuple of (stage_labels, progress_targets, soft_stage_labels) or (None, None, None) if no annotations.
- stage_labels: (B, T) hard stage indices
- progress_targets: (B, T, 1) progress values
- soft_stage_labels: (B, T, num_stages) soft probability labels, or None
"""
if self.temporal_proportions is None or episode_index is None:
return None, None, None
# Normalize inputs to numpy arrays
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine sequence length
if video_features is not None and video_features.dim() >= 2:
seq_len = video_features.shape[1]
else:
seq_len = 1
episodes_df = self.dataset_meta.episodes.to_pandas()
num_stages = len(self.subtask_names)
all_stage_labels = []
all_progress_targets = []
all_soft_stage_labels = []
has_any_soft_labels = False
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
stage_labels, progress_targets, soft_labels = self._compute_labels_for_sample(
int(frame_idx), int(ep_idx), seq_len, episodes_df
)
if stage_labels is None:
all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long))
all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32))
all_soft_stage_labels.append(None)
else:
all_stage_labels.append(stage_labels)
all_progress_targets.append(progress_targets)
all_soft_stage_labels.append(soft_labels)
if soft_labels is not None:
has_any_soft_labels = True
stacked_stage_labels = torch.stack(all_stage_labels, dim=0)
stacked_progress_targets = torch.stack(all_progress_targets, dim=0)
# Stack soft labels if any exist
stacked_soft_labels = None
if has_any_soft_labels:
soft_labels_tensors = []
for i, soft_labels in enumerate(all_soft_stage_labels):
if soft_labels is not None:
soft_labels_tensors.append(soft_labels)
else:
# Create one-hot from hard labels
one_hot = torch.zeros(seq_len, num_stages, dtype=torch.float32)
for t in range(seq_len):
one_hot[t, all_stage_labels[i][t]] = 1.0
soft_labels_tensors.append(one_hot)
stacked_soft_labels = torch.stack(soft_labels_tensors, dim=0)
return stacked_stage_labels, stacked_progress_targets, stacked_soft_labels
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Encode images, text, and normalize states in the transition."""
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
observation = new_transition.get(TransitionKey.OBSERVATION)
image = observation.get(self.image_key)
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
video_features = self._encode_images_batch(image)
observation['video_features'] = video_features
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
state_key = self.config.state_key
state_data = observation.get(state_key)
if isinstance(state_data, torch.Tensor):
state_tensor = state_data.float()
else:
state_tensor = torch.tensor(state_data, dtype=torch.float32)
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
# Get task description from dataset (complementary_data["task"])
task = comp_data.get('task')
if isinstance(task, list):
# If batch, take first task (assuming same task for all items in batch)
task = task[0] if task else ""
# Encode text with CLIP
batch_size = video_features.shape[0]
observation['text_features'] = self._encode_text_clip(task, batch_size)
frame_index = comp_data.get('index')
episode_index = comp_data.get('episode_index')
if frame_index is None:
raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
if episode_index is None:
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
# Compute episode metadata if dataset_meta is available
if self.dataset_meta is not None:
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine number of frames from video features
if video_features.dim() >= 2:
num_frames = video_features.shape[1]
else:
num_frames = 1
abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
frame_indices, episode_indices, num_frames
)
observation['absolute_frame_indices'] = abs_indices
observation['remaining_length'] = remaining
observation['episode_length'] = ep_lengths
# Generate stage labels, progress targets, and soft stage labels from subtask annotations
if self.temporal_proportions is not None and self.dataset_meta is not None:
stage_labels, progress_targets, soft_stage_labels = self._generate_stage_and_progress_labels(
frame_index, episode_index, video_features
)
if stage_labels is not None:
observation['stage_labels'] = stage_labels
observation['progress_targets'] = progress_targets
if soft_stage_labels is not None:
observation['soft_stage_labels'] = soft_stage_labels
new_transition[TransitionKey.OBSERVATION] = observation
return new_transition
@torch.no_grad()
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
"""Encode a batch of images using CLIP.
Args:
images: Batched images with shape: (B, T, C, H, W)
Returns:
Encoded feature vectors with shape (B, T, 512)
"""
batch_size, seq_length = images.shape[0], images.shape[1]
images = images.reshape(batch_size * seq_length, *images.shape[2:])
# Convert to list of PIL images
num_frames = images.shape[0]
images_list = []
for i in range(num_frames):
img = images[i]
if img.shape[0] in [1, 3]: # Channel first (C, H, W)
img = img.transpose(1, 2, 0)
# Handle single channel
if img.shape[-1] == 1:
img = np.repeat(img, 3, axis=-1)
# Convert to uint8
if img.dtype != np.uint8:
img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
images_list.append(Image.fromarray(img))
# Encode each batch
all_embeddings = []
for i in range(0, num_frames, self.config.clip_batch_size):
batch_imgs = images_list[i:i + self.config.clip_batch_size]
# Process with CLIP
inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get image embeddings
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
# Handle single frame case
if embeddings.dim() == 1:
embeddings = embeddings.unsqueeze(0)
all_embeddings.append(embeddings)
# Concatenate all embeddings
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
# Reshape back
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
return all_embeddings
@torch.no_grad()
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
"""Encode text using CLIP text encoder (per SARM paper A.4).
Args:
text: Task description text to encode
batch_size: Batch size to replicate for
Returns:
Encoded text features with shape (B, 512)
"""
# Use CLIP's tokenizer directly for text
tokenizer = self.clip_processor.tokenizer
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get text features from CLIP
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
# Replicate for batch (B, 512)
text_embedding = text_embedding.expand(batch_size, -1)
return text_embedding
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Add encoded features to the observation features."""
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(self.config.num_frames, self.config.image_dim)
)
features[PipelineFeatureType.OBSERVATION]['text_features'] = PolicyFeature(
type=FeatureType.LANGUAGE,
shape=(self.config.text_dim,)
)
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.config.num_frames, self.config.max_state_dim)
)
return features
def make_sarm_pre_post_processors(
config: SARMConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_meta = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Create pre-processor and post-processor pipelines for SARM.
The pre-processing pipeline:
1. Adds batch dimension
2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
3. SARMEncodingProcessorStep:
- Encodes images with CLIP
- Pads states to max_state_dim
- Encodes text with CLIP
4. Moves data to device
The post-processing pipeline:
1. Moves data to CPU
"""
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
SARMEncodingProcessorStep(
config=config,
dataset_meta=dataset_meta,
dataset_stats=dataset_stats
),
DeviceProcessorStep(device=config.device),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=[DeviceProcessorStep(device="cpu")],
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+257
View File
@@ -0,0 +1,257 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn.functional as F
from typing import Sequence, Any
from pydantic import BaseModel, Field
# Pydantic Models for SARM-style Annotation
class Timestamp(BaseModel):
"""Timestamp in MM:SS or SS format"""
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
end: str = Field(description="End timestamp (MM:SS or just seconds)")
class Subtask(BaseModel):
"""Individual subtask/stage - must use EXACT names from provided list"""
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
timestamps: Timestamp
class SubtaskAnnotation(BaseModel):
"""Complete annotation for a robot manipulation episode"""
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
"""
Compute dataset-level temporal proportions (priors) for each subtask.
Implements SARM Paper Formula (1):
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
where:
- M is the number of trajectories (episodes)
- L_{i,k} is the duration of subtask k in trajectory i
- T_i is the total duration of trajectory i
This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of their absolute length.
Args:
annotations: Dict mapping episode index to SubtaskAnnotation object.
Each annotation has a .subtasks list where each subtask has:
- .name: subtask name
- .timestamps.start: start time as "MM:SS" string
- .timestamps.end: end time as "MM:SS" string
fps: Frames per second (unused, kept for API compatibility)
Returns:
Dict mapping subtask name to its temporal proportion (ᾱ_k).
Proportions are normalized to sum to 1.0.
"""
subtask_proportions: dict[str, list[float]] = {}
for annotation in annotations.values():
total_duration = 0
durations: dict[str, int] = {}
for subtask in annotation.subtasks:
start_parts = subtask.timestamps.start.split(":")
end_parts = subtask.timestamps.end.split(":")
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1]) if len(start_parts) == 2 else int(start_parts[0])
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
duration = end_seconds - start_seconds
durations[subtask.name] = duration
total_duration += duration
# Calculate L_{i,k} / T_i for each subtask in this trajectory
if total_duration > 0:
for name, duration in durations.items():
if name not in subtask_proportions:
subtask_proportions[name] = []
subtask_proportions[name].append(duration / total_duration)
if not subtask_proportions:
return {}
# Average across trajectories: (1/M) × Σ_i (L_{i,k} / T_i)
avg_proportions = {
name: sum(props) / len(props)
for name, props in subtask_proportions.items()
}
# Normalize to ensure sum = 1
total = sum(avg_proportions.values())
if total > 0:
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
return avg_proportions
def compute_tau(
current_frame: int | float,
subtask_start: int | float,
subtask_end: int | float,
) -> float:
"""
Compute within-subtask normalized time τ_t.
Implements part of SARM Paper Formula (2):
τ_t = (t - s_k) / (e_k - s_k) [0, 1]
where:
- t is the current frame
- s_k is the start frame of subtask k
- e_k is the end frame of subtask k
Args:
current_frame: Current frame index (t)
subtask_start: Start frame of the subtask (s_k)
subtask_end: End frame of the subtask (e_k)
Returns:
Within-subtask progress τ_t [0, 1]
"""
subtask_duration = subtask_end - subtask_start
if subtask_duration <= 0:
return 1.0
tau = (current_frame - subtask_start) / subtask_duration
return float(np.clip(tau, 0.0, 1.0))
def compute_cumulative_progress_batch(
tau: torch.Tensor | float,
stage_indices: torch.Tensor | int,
alpha: torch.Tensor | Sequence[float],
cumulative_prior: torch.Tensor | None = None,
) -> torch.Tensor | float:
"""
Compute cumulative normalized progress from within-subtask progress.
This function implements the core formula used in SARM for both:
**Formula 2 (Training labels):**
y_t = P_{k-1} + ᾱ_k × τ_t [0, 1]
Used to compute ground-truth progress labels from subtask annotations.
- τ_t comes from annotated frame position: τ_t = (t - s_k) / (e_k - s_k)
- k is the known subtask from annotations
**Formula 4 (Inference predictions):**
ŷ_{1:N} = P̂_{k-1, 1:N} + ᾱ_{k, 1:N} × τ̂_{1:N} [0, 1]
Used to convert model outputs to cumulative progress during inference.
- τ̂ comes from the subtask MLP head (conditioned on predicted stage)
- k = Ŝ is the predicted stage from Formula 3: Ŝ = argmax(softmax(Ψ))
The formulas are mathematically identical; only the source of inputs differs:
- Training: τ and k from annotations ground-truth labels
- Inference: τ̂ and Ŝ from model predicted progress
where:
- P_{k-1} = Σ_{j=1}^{k-1} ᾱ_j is the cumulative prior (sum of previous proportions)
- ᾱ_k is the temporal proportion for subtask k (from Formula 1)
- τ is within-subtask progress [0, 1]
This ensures:
- y at start of subtask k = P_{k-1}
- y at end of subtask k = P_k
Supports both scalar and batched tensor inputs:
- Scalar: tau (float), stage_indices (int), alpha (list/sequence)
- Batch: tau (Tensor), stage_indices (Tensor), alpha (Tensor), cumulative_prior (Tensor)
Args:
tau: Within-subtask progress τ [0, 1].
For training: computed from frame position in annotated subtask.
For inference: predicted by subtask MLP head.
Scalar float or Tensor with shape (..., 1)
stage_indices: Index of current subtask k (0-indexed).
For training: known from annotations.
For inference: predicted via argmax(stage_probs) (Formula 3).
Scalar int or Tensor with shape (...)
alpha: Temporal proportions with shape (num_stages,) or Sequence[float].
Computed from dataset annotations using Formula 1.
cumulative_prior: Optional. Cumulative priors P with shape (num_stages + 1,)
where cumulative_prior[k] = P_k = Σ_{j=1}^{k} ᾱ_j.
If None, will be computed from alpha.
Returns:
Cumulative progress y [0, 1].
Scalar float if inputs are scalar, otherwise Tensor with shape (..., 1)
"""
if not isinstance(tau, torch.Tensor):
if not alpha:
raise ValueError("alpha (temporal_proportions) cannot be empty")
if isinstance(alpha, torch.Tensor):
alpha_list = alpha.tolist()
else:
alpha_list = list(alpha)
if stage_indices < 0 or stage_indices >= len(alpha_list):
raise ValueError(
f"stage_indices {stage_indices} out of range "
f"for {len(alpha_list)} subtasks"
)
# P_{k-1} = sum of proportions for subtasks 0 to k-1
P_k_minus_1 = sum(alpha_list[:stage_indices])
# ᾱ_k = proportion for current subtask
alpha_k = alpha_list[stage_indices]
# y_t = P_{k-1} + ᾱ_k × τ_t
y_t = P_k_minus_1 + alpha_k * tau
return float(np.clip(y_t, 0.0, 1.0))
if not isinstance(alpha, torch.Tensor):
alpha = torch.tensor(alpha, dtype=torch.float32)
# Compute cumulative_prior if not provided
if cumulative_prior is None:
cumulative_prior = torch.zeros(len(alpha) + 1, dtype=alpha.dtype, device=alpha.device)
cumulative_prior[1:] = torch.cumsum(alpha, dim=0)
# P_{k-1} for each predicted stage
P_k_minus_1 = cumulative_prior[stage_indices]
# ᾱ_k for each predicted stage
alpha_k = alpha[stage_indices]
# ŷ = P_{k-1} + ᾱ_k × τ̂
progress = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau
return progress
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
"""Pad the state tensor's last dimension to max_state_dim with zeros."""
current_dim = state.shape[-1]
if current_dim >= max_state_dim:
return state[..., :max_state_dim] # Truncate if larger
# Pad with zeros on the right
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
return F.pad(state, padding, mode='constant', value=0)
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -102,6 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
def __post_init__(self):
super().__post_init__()
+101 -19
View File
@@ -54,12 +54,15 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
import math
from collections import deque
from typing import TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
@@ -69,6 +72,12 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
from lerobot.utils.utils import get_safe_dtype
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
@@ -232,8 +241,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
super().__init__(config)
config.validate_features()
self.config = config
self.model = VLAFlowMatching(config)
self.init_rtc_processor()
self.model = VLAFlowMatching(config, rtc_processor=self.rtc_processor)
self.reset()
def reset(self):
@@ -242,10 +251,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Lets create processor if the config provided
# If RTC is not enabled - we still can track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
# In case of calling init_rtc_processor after the model is created
# We need to set the rtc_processor to the model
# During the normal initialization process the model is not created yet
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def get_optim_params(self) -> dict:
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def _get_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
# TODO: Check if this for loop is needed.
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
# In the case of offline inference, we have the action in the batch
@@ -260,7 +287,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise, **kwargs
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
@@ -278,30 +307,37 @@ class SmolVLAPolicy(PreTrainedPolicy):
return batch
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def predict_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
actions = self._get_action_chunk(batch, noise)
actions = self._get_action_chunk(batch, noise, **kwargs)
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
def select_action(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._queues[ACTION]) == 0:
if self._check_get_actions_condition():
actions = self._get_action_chunk(batch, noise)
# `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
@@ -310,6 +346,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
return self._queues[ACTION].popleft()
def _check_get_actions_condition(self) -> bool:
return len(self._queues[ACTION]) == 0
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
@@ -471,7 +513,7 @@ class VLAFlowMatching(nn.Module):
"""
def __init__(self, config: SmolVLAConfig):
def __init__(self, config: SmolVLAConfig, rtc_processor: RTCProcessor | None = None):
super().__init__()
self.config = config
@@ -485,7 +527,6 @@ class VLAFlowMatching(nn.Module):
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
device=self.config.device,
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
@@ -510,6 +551,10 @@ class VLAFlowMatching(nn.Module):
self.add_image_special_tokens = self.config.add_image_special_tokens
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
@@ -706,7 +751,16 @@ class VLAFlowMatching(nn.Module):
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
def sample_actions(
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
@@ -734,17 +788,45 @@ class VLAFlowMatching(nn.Module):
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(
+51
View File
@@ -22,6 +22,8 @@ import numpy as np
import torch
from torch import nn
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.utils import build_dataset_frame
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
@@ -198,3 +200,52 @@ def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict])
f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
}
return act_processed_policy
def raise_feature_mismatch_error(
provided_features: set[str],
expected_features: set[str],
) -> None:
"""
Raises a standardized ValueError for feature mismatches between dataset/environment and policy config.
"""
missing = expected_features - provided_features
extra = provided_features - expected_features
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
raise ValueError(
f"Feature mismatch between dataset/environment and policy config.\n"
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
f"Please ensure your dataset and policy use consistent feature names.\n"
f"If your dataset uses different observation keys (e.g., cameras named differently), "
f"use the `--rename_map` argument, for example:\n"
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
f'"observation.images.top": "observation.images.camera2"}}\''
)
def validate_visual_features_consistency(
cfg: PreTrainedConfig,
features: dict[str, PolicyFeature],
) -> None:
"""
Validates visual feature consistency between a policy config and provided dataset/environment features.
Validation passes if EITHER:
- Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
- Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
Args:
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
"""
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
# Accept if either direction is a subset
policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
if not (policy_subset_of_dataset or dataset_subset_of_policy):
raise_feature_mismatch_error(provided_visuals, expected_visuals)
+2 -1
View File
@@ -170,8 +170,9 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key}
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
def create_transition(
+154
View File
@@ -0,0 +1,154 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_quat = robot_state["eef"]["quat"] # (B, 4,)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into a single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def observation(self, observation):
return self._process_observation(observation)
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
"""
Convert batched quaternions to axis-angle format.
Only accepts torch tensors of shape (B, 4).
Args:
quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format
Returns:
Tensor: (B, 3) axis-angle vectors
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 4)
"""
if not isinstance(quat, torch.Tensor):
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
if quat.ndim != 2 or quat.shape[1] != 4:
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
quat = quat.to(dtype=torch.float32)
device = quat.device
batch_size = quat.shape[0]
w = quat[:, 3].clamp(-1.0, 1.0)
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
result = torch.zeros((batch_size, 3), device=device)
mask = den > 1e-10
if mask.any():
angle = 2.0 * torch.acos(w[mask]) # (M,)
axis = quat[mask, :3] / den[mask].unsqueeze(1)
result[mask] = axis * angle.unsqueeze(1)
return result
-20
View File
@@ -1,20 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarms_follower import OpenArmsFollowerConfig
from .openarms_follower import OpenArmsFollower
__all__ = ["OpenArmsFollower", "OpenArmsFollowerConfig"]
@@ -1,118 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Dict, Optional
from lerobot.cameras import CameraConfig
from lerobot.motors.damiao.tables import MotorType
from ..config import RobotConfig
@RobotConfig.register_subclass("openarms_follower")
@dataclass
class OpenArmsFollowerConfig(RobotConfig):
"""Configuration for the OpenArms follower robot with Damiao motors."""
# CAN interfaces - one per arm
# Right arm CAN interface (e.g., "can0")
# Left arm CAN interface (e.g., "can1")
# Linux: "can0", "can1", etc.
# macOS: "/dev/cu.usbmodem*" (serial device)
port_right: str = "can0" # CAN interface for right arm
port_left: str = "can1" # CAN interface for left arm
# CAN interface type: "socketcan" (Linux), "slcan" (macOS/serial), or "auto" (auto-detect)
can_interface: str = "socketcan"
# CAN FD settings (OpenArms uses CAN FD by default)
use_can_fd: bool = True
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
# Whether to disable torque when disconnecting
disable_torque_on_disconnect: bool = True
# Safety limit for relative target positions
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
max_relative_target: Optional[float | Dict[str, float]] = None
# Camera configurations
cameras: Dict[str, CameraConfig] = field(default_factory=dict)
# Motor configuration for OpenArms (7 DOF per arm)
# Maps motor names to (send_can_id, recv_can_id, motor_type)
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms uses 4 types of motors:
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
# - DM4340P and DM4340 for shoulder rotation and elbow
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
motor_config: Dict[str, tuple[int, int, str]] = field(default_factory=lambda: {
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
})
# MIT control parameters for position control (used in send_action)
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
position_kp: list[float] = field(default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0])
position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2])
# Damping gains for stability when applying torque compensation (gravity/friction)
# Used when kp=0 and only torque is applied
damping_kd: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1])
# Friction model parameters: τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
# From OpenArms config/follower.yaml
friction_fc: list[float] = field(default_factory=lambda: [0.306, 0.306, 0.40, 0.166, 0.050, 0.093, 0.172, 0.0512]) # Coulomb friction [Nm]
friction_k: list[float] = field(default_factory=lambda: [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000]) # tanh steepness
friction_fv: list[float] = field(default_factory=lambda: [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084]) # Viscous friction [Nm·s/rad]
friction_fo: list[float] = field(default_factory=lambda: [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050]) # Offset torque [Nm]
# Calibration parameters
calibration_mode: str = "manual" # "manual" or "auto"
zero_position_on_connect: bool = False # Set zero position on connect
# Joint limits for position clipping (degrees)
# Format: [min, max] for each joint
# These limits clip commands in send_action to prevent mechanical damage
joint_limits_right: Dict[str, tuple[float, float]] = field(default_factory=lambda: {
"joint_1": (-75.0, 75.0),
"joint_2": (-9.0, 90.0),
"joint_3": (-85.0, 85.0),
"joint_4": (0.0, 135.0),
"joint_5": (-85.0, 85.0),
"joint_6": (-40.0, 40.0),
"joint_7": (-80.0, 80.0),
"gripper": (-65.0, 0.0),
})
joint_limits_left: Dict[str, tuple[float, float]] = field(default_factory=lambda: {
"joint_1": (-75.0, 75.0),
"joint_2": (-90.0, 9.0),
"joint_3": (-85.0, 85.0),
"joint_4": (0.0, 135.0),
"joint_5": (-85.0, 85.0),
"joint_6": (-40.0, 40.0),
"joint_7": (-80.0, 80.0),
"gripper": (-65.0, 0.0),
})
@@ -1,698 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from functools import cached_property
from typing import Any, Dict, Optional
import numpy as np
import pinocchio as pin
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.motors.damiao.tables import MotorType
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_openarms_follower import OpenArmsFollowerConfig
logger = logging.getLogger(__name__)
class OpenArmsFollower(Robot):
"""
OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper.
The arm uses Damiao motors in MIT control mode.
"""
config_class = OpenArmsFollowerConfig
name = "openarms_follower"
def __init__(self, config: OpenArmsFollowerConfig):
super().__init__(config)
self.config = config
norm_mode_body = MotorNormMode.DEGREES # Always use degrees for Damiao motors
# Right arm motors (on port_right)
# Each arm uses the same CAN IDs since they're on separate buses
motors_right = {}
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
motor = Motor(send_id, motor_type_str, norm_mode_body)
motor.recv_id = recv_id
motor.motor_type = getattr(MotorType, motor_type_str.upper().replace("-", "_"))
motors_right[motor_name] = motor
# Left arm motors (on port_left, same IDs as right since separate bus)
motors_left = {}
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
motor = Motor(send_id, motor_type_str, norm_mode_body)
motor.recv_id = recv_id
motor.motor_type = getattr(MotorType, motor_type_str.upper().replace("-", "_"))
motors_left[motor_name] = motor
# Initialize separate Damiao motors buses (one per arm) with CAN FD support
self.bus_right = DamiaoMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration={k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")},
can_interface=self.config.can_interface,
use_can_fd=self.config.use_can_fd,
bitrate=self.config.can_bitrate,
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
)
self.bus_left = DamiaoMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration={k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")},
can_interface=self.config.can_interface,
use_can_fd=self.config.use_can_fd,
bitrate=self.config.can_bitrate,
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
)
# Initialize cameras
self.cameras = make_cameras_from_configs(config.cameras)
# Cache for last valid camera frames (to avoid blocking on slow USB reads)
self.camera_frame_cache = {key: None for key in self.cameras.keys()}
# Initialize Pinocchio robot model for dynamics (optional)
self.pin_robot = None
try:
# Load URDF - try external path first (with meshes), then repository
import os
from os.path import expanduser, dirname
# Try external URDF with meshes first
external_urdf_path = expanduser("~/Documents/openarm_description/openarm_bimanual_pybullet.urdf")
if os.path.exists(external_urdf_path):
urdf_path = external_urdf_path
urdf_dir = dirname(urdf_path)
self.pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, urdf_dir)
self.pin_robot.data = self.pin_robot.model.createData()
logger.info(f"Loaded OpenArms URDF for dynamics computation from {urdf_path}")
except Exception as e:
logger.warning(f"Could not load URDF for dynamics: {e}. Gravity compensation will not be available.")
@property
def _motors_ft(self) -> Dict[str, type]:
"""Motor features for observation and action spaces."""
features = {}
# Right arm motors - only positions stored in dataset
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
# Left arm motors - only positions stored in dataset
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property
def _cameras_ft(self) -> Dict[str, tuple]:
"""Camera features for observation space."""
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3)
for cam in self.cameras
}
@cached_property
def observation_features(self) -> Dict[str, type | tuple]:
"""Combined observation features from motors and cameras."""
return {**self._motors_ft, **self._cameras_ft}
@cached_property
def action_features(self) -> Dict[str, type]:
"""Action features (motor positions only)."""
return self._motors_ft
@property
def is_connected(self) -> bool:
"""Check if robot is connected."""
return (self.bus_right.is_connected and
self.bus_left.is_connected and
all(cam.is_connected for cam in self.cameras.values()))
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the robot and optionally calibrate.
We assume that at connection time, the arms are in a safe rest position,
and torque can be safely disabled to run calibration if needed.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to both CAN buses
logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
# Run calibration if needed
if calibrate:
logger.info(
"No calibration found or overwriting calibration. Running calibration..."
)
self.calibrate()
# Connect cameras
for cam in self.cameras.values():
cam.connect()
# Configure motors
self.configure()
# Optionally set zero position
if self.config.zero_position_on_connect:
logger.info("Setting current position as zero...")
self.bus_right.set_zero_position()
self.bus_left.set_zero_position()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
"""Check if robot is calibrated."""
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for OpenArms robot.
The calibration procedure:
1. Disable torque
2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position
4. Record range of motion for each joint
5. Save calibration
"""
if self.calibration:
# Ask user whether to use existing calibration
user_input = input(
f"Press ENTER to use existing calibration for {self.id}, "
f"or type 'c' and press ENTER to run new calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}")
# Split calibration for each bus
cal_right = {k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")}
cal_left = {k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return
logger.info(f"\nRunning calibration for {self}")
# Calibrate each arm separately
self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
def _calibrate_arm(self, arm_name: str, bus: DamiaoMotorsBus) -> None:
"""Calibrate a single arm."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
# Disable torque for manual positioning
bus.disable_torque()
time.sleep(0.1)
# Step 1: Set zero position
input(
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
# Set current position as zero for all motors
bus.set_zero_position()
logger.info(f"{arm_name.capitalize()} arm zero position set.")
# Automatically set range to -90° to +90° for all joints
print(
f"\nAutomatically setting range: -90° to +90° for all joints"
)
# Create calibration data with fixed ranges
if self.calibration is None:
self.calibration = {}
for motor_name, motor in bus.motors.items():
# Prefix motor name with arm name for storage
prefixed_name = f"{arm_name}_{motor_name}"
# Use -90 to +90 for all joints and gripper (integers required)
self.calibration[prefixed_name] = MotorCalibration(
id=motor.id,
drive_mode=0, # Normal direction
homing_offset=0, # Already set via set_zero_position
range_min=-90, # -90 degrees (integer)
range_max=90, # +90 degrees (integer)
)
logger.info(f" {prefixed_name}: range set to [-90°, +90°]")
# Write calibration to this arm's motors
cal_for_bus = {k.replace(f"{arm_name}_", ""): v for k, v in self.calibration.items() if k.startswith(f"{arm_name}_")}
bus.write_calibration(cal_for_bus)
# Re-enable torque
bus.enable_torque()
# Save calibration after each arm
self._save_calibration()
def configure(self) -> None:
"""Configure motors with appropriate settings."""
# Configure right arm
with self.bus_right.torque_disabled():
self.bus_right.configure_motors()
# Configure left arm
with self.bus_left.torque_disabled():
self.bus_left.configure_motors()
def setup_motors(self) -> None:
raise NotImplementedError("Motor ID configuration is typically done via manufacturer tools for CAN motors.")
def get_observation(self) -> Dict[str, Any]:
"""
Get current observation from robot including position, velocity, and torque.
OPTIMIZED: Reads all motor states (pos/vel/torque) in one CAN refresh cycle
instead of 3 separate reads.
Note: Velocity and torque are read but not stored in dataset (only used for
internal calculations). Only positions and camera images are stored.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict = {}
# Detailed profiling for bottleneck analysis
timings = {}
# OPTIMIZED: Use sync_read_all_states to get pos/vel/torque in one go
t0 = time.perf_counter()
right_states = self.bus_right.sync_read_all_states()
timings["right_motors"] = (time.perf_counter() - t0) * 1000
for motor in self.bus_right.motors:
state = right_states.get(motor, {})
obs_dict[f"right_{motor}.pos"] = state.get("position", 0.0)
obs_dict[f"right_{motor}.vel"] = state.get("velocity", 0.0)
obs_dict[f"right_{motor}.torque"] = state.get("torque", 0.0)
# OPTIMIZED: Use sync_read_all_states to get pos/vel/torque in one go
t0 = time.perf_counter()
left_states = self.bus_left.sync_read_all_states()
timings["left_motors"] = (time.perf_counter() - t0) * 1000
for motor in self.bus_left.motors:
state = left_states.get(motor, {})
obs_dict[f"left_{motor}.pos"] = state.get("position", 0.0)
obs_dict[f"left_{motor}.vel"] = state.get("velocity", 0.0)
obs_dict[f"left_{motor}.torque"] = state.get("torque", 0.0)
# Capture images from cameras (with individual timing)
# Use async_read with very short timeout to avoid blocking on slow USB cameras
for cam_key, cam in self.cameras.items():
t0 = time.perf_counter()
try:
# Use 5ms timeout - if frame isn't ready, reuse last frame
frame = cam.async_read(timeout_ms=5)
self.camera_frame_cache[cam_key] = frame # Update cache
obs_dict[cam_key] = frame
except TimeoutError:
# If no new frame available, reuse last valid frame from cache
# This prevents blocking the entire control loop on slow USB reads
if self.camera_frame_cache[cam_key] is not None:
obs_dict[cam_key] = self.camera_frame_cache[cam_key]
logger.debug(f"Camera {cam_key} timeout, reusing cached frame")
# Store timing with padded name to align output (e.g. "left_wrist ")
timings[f"{cam_key:14s}"] = (time.perf_counter() - t0) * 1000
# Log detailed timings (for debugging slow observations)
if logger.isEnabledFor(logging.DEBUG):
total_time = sum(timings.values())
breakdown = " | ".join([f"{k}: {v:.1f}ms" for k, v in timings.items()])
logger.debug(f"{self} get_observation: {total_time:.1f}ms total | {breakdown}")
# Store timings in obs_dict for external profiling
obs_dict["_timing_breakdown"] = timings
return obs_dict
def send_action(
self,
action: Dict[str, Any],
custom_kp: Optional[Dict[str, float]] = None,
custom_kd: Optional[Dict[str, float]] = None
) -> Dict[str, Any]:
"""
Send action command to robot.
The action magnitude may be clipped based on safety limits.
Args:
action: Dictionary with motor positions (e.g., "right_joint_1.pos", "left_joint_2.pos")
custom_kp: Optional custom kp gains per motor (e.g., {"right_joint_1": 120.0, "left_joint_2": 150.0})
custom_kd: Optional custom kd gains per motor (e.g., {"right_joint_1": 1.5, "left_joint_2": 2.0})
Returns:
The action actually sent (potentially clipped)
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Extract motor positions from action and split by arm
goal_pos_right = {}
goal_pos_left = {}
for key, val in action.items():
if key.endswith(".pos"):
motor_name = key.removesuffix(".pos")
if motor_name.startswith("right_"):
# Remove "right_" prefix for bus access
goal_pos_right[motor_name.removeprefix("right_")] = val
elif motor_name.startswith("left_"):
# Remove "left_" prefix for bus access
goal_pos_left[motor_name.removeprefix("left_")] = val
# Apply joint limit clipping to right arm
for motor_name, position in goal_pos_right.items():
if motor_name in self.config.joint_limits_right:
min_limit, max_limit = self.config.joint_limits_right[motor_name]
clipped_position = max(min_limit, min(max_limit, position))
if clipped_position != position:
logger.debug(f"Clipped right_{motor_name} from {position:.2f}° to {clipped_position:.2f}°")
goal_pos_right[motor_name] = clipped_position
# Apply joint limit clipping to left arm
for motor_name, position in goal_pos_left.items():
if motor_name in self.config.joint_limits_left:
min_limit, max_limit = self.config.joint_limits_left[motor_name]
clipped_position = max(min_limit, min(max_limit, position))
if clipped_position != position:
logger.debug(f"Clipped left_{motor_name} from {position:.2f}° to {clipped_position:.2f}°")
goal_pos_left[motor_name] = clipped_position
# Apply safety limits if configured
if self.config.max_relative_target is not None:
# Get current positions
present_pos_right = self.bus_right.sync_read("Present_Position")
present_pos_left = self.bus_left.sync_read("Present_Position")
# Apply safety limits to right arm
if goal_pos_right:
goal_present_pos_right = {
key: (g_pos, present_pos_right.get(key, 0.0))
for key, g_pos in goal_pos_right.items()
}
goal_pos_right = ensure_safe_goal_position(
goal_present_pos_right,
self.config.max_relative_target
)
# Apply safety limits to left arm
if goal_pos_left:
goal_present_pos_left = {
key: (g_pos, present_pos_left.get(key, 0.0))
for key, g_pos in goal_pos_left.items()
}
goal_pos_left = ensure_safe_goal_position(
goal_present_pos_left,
self.config.max_relative_target
)
# Motor name to index mapping for gains
motor_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
# Use batch MIT control for right arm (sends all commands, then collects responses)
if goal_pos_right:
commands_right = {}
for motor_name, position_degrees in goal_pos_right.items():
idx = motor_index.get(motor_name, 0)
# Use custom gains if provided, otherwise use config defaults
full_motor_name = f"right_{motor_name}"
if custom_kp is not None and full_motor_name in custom_kp:
kp = custom_kp[full_motor_name]
else:
kp = self.config.position_kp[idx] if isinstance(self.config.position_kp, list) else self.config.position_kp
if custom_kd is not None and full_motor_name in custom_kd:
kd = custom_kd[full_motor_name]
else:
kd = self.config.position_kd[idx] if isinstance(self.config.position_kd, list) else self.config.position_kd
commands_right[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
self.bus_right._mit_control_batch(commands_right)
# Use batch MIT control for left arm (sends all commands, then collects responses)
if goal_pos_left:
commands_left = {}
for motor_name, position_degrees in goal_pos_left.items():
idx = motor_index.get(motor_name, 0)
# Use custom gains if provided, otherwise use config defaults
full_motor_name = f"left_{motor_name}"
if custom_kp is not None and full_motor_name in custom_kp:
kp = custom_kp[full_motor_name]
else:
kp = self.config.position_kp[idx] if isinstance(self.config.position_kp, list) else self.config.position_kp
if custom_kd is not None and full_motor_name in custom_kd:
kd = custom_kd[full_motor_name]
else:
kd = self.config.position_kd[idx] if isinstance(self.config.position_kd, list) else self.config.position_kd
commands_left[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
self.bus_left._mit_control_batch(commands_left)
# Return the actions that were actually sent
result = {}
for motor, val in goal_pos_right.items():
result[f"right_{motor}.pos"] = val
for motor, val in goal_pos_left.items():
result[f"left_{motor}.pos"] = val
return result
def disconnect(self):
"""Disconnect from robot."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Disconnect from CAN buses
self.bus_right.disconnect(self.config.disable_torque_on_disconnect)
self.bus_left.disconnect(self.config.disable_torque_on_disconnect)
# Disconnect cameras
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")
def _deg_to_rad(self, deg: Dict[str, float | int]) -> Dict[str, float]:
"""Convert degrees to radians for all motors."""
return {m: np.deg2rad(float(v)) for m, v in deg.items()}
def _gravity_from_q(self, q_rad: Dict[str, float]) -> Dict[str, float]:
"""
Compute g(q) [N·m] for all joints in the robot.
The order of joints in the URDF matches the concatenated motor lists (right then left).
Args:
q_rad: Dictionary mapping motor names (with arm prefix) to positions in radians
Returns:
Dictionary mapping motor names to gravity torques in N·m
Raises:
RuntimeError: If URDF model is not loaded
"""
if self.pin_robot is None:
raise RuntimeError(
"Cannot compute gravity: URDF model not loaded. "
"Ensure urdf/openarms.urdf exists and is valid."
)
# Build position vector in the order of motors (left arm, then right arm)
# This order must match the URDF joint order
# URDF has: left_joint1-7, left_finger_joint1-2, right_joint1-7, right_finger_joint1-2
q = np.zeros(self.pin_robot.model.nq)
idx = 0
# Left arm motors (first in URDF) - joints 1-7
for motor_name in self.bus_left.motors:
if motor_name == "gripper":
continue # Skip gripper, will be handled separately
full_name = f"left_{motor_name}"
q[idx] = q_rad.get(full_name, 0.0)
idx += 1
# Skip left finger joints (leave as zeros)
idx += 2
# Right arm motors (second in URDF) - joints 1-7
for motor_name in self.bus_right.motors:
if motor_name == "gripper":
continue # Skip gripper, will be handled separately
full_name = f"right_{motor_name}"
q[idx] = q_rad.get(full_name, 0.0)
idx += 1
# Skip right finger joints (leave as zeros)
idx += 2
# Compute generalized gravity vector
g = pin.computeGeneralizedGravity(self.pin_robot.model, self.pin_robot.data, q)
# Map back to motor names (only arm joints, not fingers)
result = {}
idx = 0
# Left arm torques (joints 1-7)
for motor_name in self.bus_left.motors:
if motor_name == "gripper":
result["left_gripper"] = 0.0 # No gravity compensation for gripper
continue
result[f"left_{motor_name}"] = float(g[idx])
idx += 1
# Skip left finger joint torques in output
idx += 2
# Right arm torques (joints 1-7)
for motor_name in self.bus_right.motors:
if motor_name == "gripper":
result["right_gripper"] = 0.0 # No gravity compensation for gripper
continue
result[f"right_{motor_name}"] = float(g[idx])
idx += 1
# Skip right finger joint torques in output
idx += 2
return result
def _friction_from_velocity(
self,
velocity_rad_per_sec: Dict[str, float],
friction_scale: float = 1.0,
amp_tmp: float = 1.0,
coef_tmp: float = 0.1
) -> Dict[str, float]:
"""
Compute friction torques for all joints in the robot using tanh friction model.
Args:
velocity_rad_per_sec: Dictionary mapping motor names (with arm prefix) to velocities in rad/s
friction_scale: Scale factor for friction compensation (default 1.0, use 0.3 for stability)
amp_tmp: Amplitude factor for tanh term (default 1.0)
coef_tmp: Coefficient for tanh steepness (default 0.1)
Returns:
Dictionary mapping motor names to friction torques in N·m
"""
# Motor name to index mapping
motor_name_to_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
result = {}
# Process all motors (left and right)
for motor_full_name, velocity in velocity_rad_per_sec.items():
# Extract motor name without arm prefix
if motor_full_name.startswith("right_"):
motor_name = motor_full_name.removeprefix("right_")
elif motor_full_name.startswith("left_"):
motor_name = motor_full_name.removeprefix("left_")
else:
result[motor_full_name] = 0.0
continue
# Get motor index for friction parameters
motor_index = motor_name_to_index.get(motor_name, 0)
# Get friction parameters from config
Fc = self.config.friction_fc[motor_index]
k = self.config.friction_k[motor_index]
Fv = self.config.friction_fv[motor_index]
Fo = self.config.friction_fo[motor_index]
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
friction_torque = (
amp_tmp * Fc * np.tanh(coef_tmp * k * velocity) +
Fv * velocity +
Fo
)
# Apply scale factor
friction_torque *= friction_scale
result[motor_full_name] = float(friction_torque)
return result
def get_damping_kd(self, motor_name: str) -> float:
"""
Get damping gain (Kd) for a specific motor.
Args:
motor_name: Motor name without arm prefix (e.g., "joint_1", "gripper")
Returns:
Damping gain value
"""
motor_name_to_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
motor_index = motor_name_to_index.get(motor_name, 0)
return self.config.damping_kd[motor_index]
-41
View File
@@ -1,41 +0,0 @@
# Eun visualizer locally
# login to hf an set your access token
hf auth login
# if not installed, install with: pip install huggingface_hub
git clone https://github.com/huggingface/lerobot-dataset-visualizer.git
cd lerobot-dataset-visualizer
python -m lerobot_dataset_viz --repo-id lerobot-data-collection/repo-id-nez --episode-index 0
git checkout feat/private_repo_viz
npm install
npm run dev
# open http://localhost:3000 in your browser
# ======================================================
# default merge command; copy your list of datasets ids in repo_ids
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot-data-collection/repo-id-nez \
--operation.type merge --push_to_hub true \
--operation.repo_ids "[]"
# merge test datasets into one
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot-data-collection/test-2025-11-03-merged \
--operation.type merge --push_to_hub true \
--operation.repo_ids "['lerobot-data-collection/test-2025-11-03-13-18', 'lerobot-data-collection/test-2025-11-03-13-19', 'lerobot-data-collection/test-2025-11-03-13-20', 'lerobot-data-collection/test-2025-11-03-13-21', 'lerobot-data-collection/test-2025-11-03-13-23', 'lerobot-data-collection/test-2025-11-03-13-24', 'lerobot-data-collection/test-2025-11-03-13-25', 'lerobot-data-collection/test-2025-11-03-13-26', 'lerobot-data-collection/test-2025-11-03-13-27', 'lerobot-data-collection/test-2025-11-03-13-29', 'lerobot-data-collection/test-2025-11-03-13-30', 'lerobot-data-collection/test-2025-11-03-13-31', 'lerobot-data-collection/test-2025-11-03-13-34', 'lerobot-data-collection/test-2025-11-03-13-41', 'lerobot-data-collection/test-2025-11-03-13-42', 'lerobot-data-collection/test-2025-11-03-13-43', 'lerobot-data-collection/test-2025-11-03-13-44', 'lerobot-data-collection/test-2025-11-03-13-45', 'lerobot-data-collection/test-2025-11-03-13-46', 'lerobot-data-collection/test-2025-11-03-13-47', 'lerobot-data-collection/test-2025-11-03-13-48', 'lerobot-data-collection/test-2025-11-03-13-49']"
# RUN loop_dataset.py to get your repo_ids
# ========================================================= Two folds datasets
#merge
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot-data-collection/two-folds-dataset-full-11-04 \
--operation.type merge --push_to_hub true \
--operation.repo_ids "['lerobot-data-collection/two-folds-dataset-2025-11-04-15-06', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-08', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-10', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-11', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-12', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-14', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-16', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-18', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-20', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-22', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-24', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-25', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-27', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-28', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-29', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-33', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-34', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-35', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-36', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-52', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-53', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-54', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-55', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-56', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-57', 'lerobot-data-collection/two-folds-dataset-2025-11-04-15-59', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-00', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-01', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-02', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-03', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-04', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-05', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-06', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-07', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-08', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-09', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-26', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-28', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-29', 'lerobot-data-collection/two-folds-dataset-2025-11-04-16-30']"
+33 -1
View File
@@ -71,7 +71,7 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import (
add_envs_task,
check_env_attributes_and_types,
@@ -94,6 +94,8 @@ from lerobot.utils.utils import (
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
seeds: list[int] | None = None,
@@ -165,11 +167,19 @@ def rollout(
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
@@ -239,6 +249,8 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -319,6 +331,8 @@ def eval_policy(
rollout_data = rollout(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
seeds=list(seeds) if seeds else None,
@@ -517,10 +531,16 @@ def eval_main(cfg: EvalPipelineConfig):
pretrained_path=cfg.policy.pretrained_path,
preprocessor_overrides=preprocessor_overrides,
)
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
envs=envs,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
@@ -561,6 +581,8 @@ def eval_one(
env: gym.vector.VectorEnv,
*,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -576,6 +598,8 @@ def eval_one(
task_result = eval_policy(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -600,6 +624,8 @@ def run_one(
env,
*,
policy,
env_preprocessor,
env_postprocessor,
preprocessor,
postprocessor,
n_episodes: int,
@@ -622,6 +648,8 @@ def run_one(
metrics = eval_one(
env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -639,6 +667,8 @@ def run_one(
def eval_policy_all(
envs: dict[str, dict[int, gym.vector.VectorEnv]],
policy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -694,6 +724,8 @@ def eval_policy_all(
task_runner = partial(
run_one,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -42,7 +42,6 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config,
so100_leader,
so101_leader,
openarms_mini
)
COMPATIBLE_DEVICES = [
@@ -53,7 +52,6 @@ COMPATIBLE_DEVICES = [
"so101_follower",
"so101_leader",
"lekiwi",
"openarms_mini",
]
+69 -4
View File
@@ -29,7 +29,7 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors
@@ -61,6 +61,7 @@ def update_policy(
accelerator: Accelerator,
lr_scheduler=None,
lock=None,
rabc_weight_computer=None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -85,10 +86,22 @@ def update_policy(
"""
start_time = time.perf_counter()
policy.train()
# Compute RA-BC weights if enabled
rabc_weights = None
if rabc_weight_computer is not None:
rabc_weights = rabc_weight_computer.compute_batch_weights(batch)
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
# Apply RA-BC weights if enabled
if rabc_weights is not None:
# Weight the loss
loss = loss * rabc_weights.mean()
output_dict['rabc_mean_weight'] = rabc_weights.mean().item()
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
@@ -140,8 +153,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg: A `TrainPipelineConfig` object containing all training configurations.
accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
cfg.validate()
# Create Accelerator if not provided
# It will automatically detect if running in distributed mode or single-process mode
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
@@ -158,6 +169,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# When using accelerate, only the main process should log to avoid duplicate outputs
is_main_process = accelerator.is_main_process
cfg.validate()
# Only log on main process
if is_main_process:
logging.info(pformat(cfg.to_dict()))
@@ -215,6 +228,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For SARM, always provide dataset_meta for progress normalization
if cfg.policy.type == "sarm":
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
@@ -246,6 +263,28 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
# Load reward model for RA-BC if enabled
rabc_weight_computer = None
if cfg.use_rabc:
logging.info(f"Loading reward model for RA-BC from {cfg.reward_model_path}")
from lerobot.policies.factory import get_policy_class
from lerobot.utils.rabc import RABCWeightComputer
# Detect reward model type from path
# For now, assume SARM if not specified
reward_model_class = get_policy_class("sarm")
reward_model = reward_model_class.from_pretrained(cfg.reward_model_path)
reward_model.to(device)
reward_model.eval()
rabc_weight_computer = RABCWeightComputer(
reward_model=reward_model,
kappa=cfg.rabc_kappa,
epsilon=cfg.rabc_epsilon,
device=device,
)
logging.info("RA-BC weight computer initialized")
step = 0 # number of policy updates (forward + backward + optim)
@@ -259,6 +298,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info("Creating environment processors")
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
@@ -274,9 +315,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
elif cfg.policy.type == "sarm" and getattr(cfg.policy, "use_temporal_sampler", False):
# Use SARM temporal sampler for reward model training
from lerobot.datasets.temporal_sampler import SARMTemporalSampler
shuffle = False
sampler = SARMTemporalSampler(
dataset_from_index=dataset.meta.episodes["dataset_from_index"],
dataset_to_index=dataset.meta.episodes["dataset_to_index"],
frame_gap=getattr(cfg.policy, "frame_gap", 30),
shuffle=True,
seed=cfg.seed,
)
else:
shuffle = True
sampler = None
@@ -321,7 +375,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
)
if is_main_process:
logging.info("Start offline training on a fixed dataset")
logging.info(f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}")
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
@@ -337,6 +391,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weight_computer=rabc_weight_computer,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -353,6 +408,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weight_computer is not None:
rabc_stats = rabc_weight_computer.get_stats()
wandb_log_dict.update({
'rabc_progress_mean': rabc_stats['mean'],
'rabc_progress_std': rabc_stats['std'],
'rabc_samples_seen': rabc_stats['count'],
})
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
@@ -384,6 +447,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env
policy=accelerator.unwrap_model(policy),
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
@@ -1,20 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarms_leader import OpenArmsLeaderConfig
from .openarms_leader import OpenArmsLeader
__all__ = ["OpenArmsLeader", "OpenArmsLeaderConfig"]
@@ -1,80 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Dict
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("openarms_leader")
@dataclass
class OpenArmsLeaderConfig(TeleoperatorConfig):
"""Configuration for the OpenArms leader/teleoperator with Damiao motors."""
# CAN interfaces - one per arm
# Right arm CAN interface (e.g., "can2")
# Left arm CAN interface (e.g., "can3")
# Linux: "can0", "can1", etc.
# macOS: "/dev/cu.usbmodem*" (serial device)
port_right: str = "can2" # CAN interface for right arm
port_left: str = "can3" # CAN interface for left arm
# CAN interface type: "socketcan" (Linux), "slcan" (macOS/serial), or "auto" (auto-detect)
can_interface: str = "socketcan"
# CAN FD settings (OpenArms uses CAN FD by default)
use_can_fd: bool = True
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
# Motor configuration for OpenArms (7 DOF per arm)
# Maps motor names to (send_can_id, recv_can_id, motor_type)
# Based on: https://docs.openarm.dev/software/setup/configure-test
# OpenArms uses 4 types of motors:
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
# - DM4340P and DM4340 for shoulder rotation and elbow
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
motor_config: Dict[str, tuple[int, int, str]] = field(default_factory=lambda: {
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
})
# Torque mode settings for manual control
# When enabled, motors have torque disabled for manual movement
manual_control: bool = True
# MIT control parameters (used when manual_control=False for torque control)
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
position_kp: list[float] = field(default_factory=lambda: [100.0, 100.0, 100.0, 48.0, 24.0, 31.0, 25.0, 16.0])
position_kd: list[float] = field(default_factory=lambda: [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
# Damping gains for stability when applying torque compensation (gravity/friction)
# Used when kp=0 and only torque is applied
damping_kd: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1])
# Friction model parameters: τ_fric(ω) = Fo + Fv·ω + Fc·tanh(k·ω)
# From OpenArms config/leader.yaml (note: Fc[5] is slightly different: 0.083 vs 0.093)
friction_fc: list[float] = field(default_factory=lambda: [0.306, 0.306, 0.40, 0.166, 0.050, 0.083, 0.172, 0.0512]) # Coulomb friction [Nm]
friction_k: list[float] = field(default_factory=lambda: [28.417, 28.417, 29.065, 130.038, 151.771, 242.287, 7.888, 4.000]) # tanh steepness
friction_fv: list[float] = field(default_factory=lambda: [0.063, 0.0630, 0.604, 0.813, 0.029, 0.072, 0.084, 0.084]) # Viscous friction [Nm·s/rad]
friction_fo: list[float] = field(default_factory=lambda: [0.088, 0.088, 0.008, -0.058, 0.005, 0.009, -0.059, -0.050]) # Offset torque [Nm]
@@ -1,503 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any, Dict
import numpy as np
import pinocchio as pin
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.motors.damiao.tables import MotorType
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_openarms_leader import OpenArmsLeaderConfig
logger = logging.getLogger(__name__)
class OpenArmsLeader(Teleoperator):
"""
OpenArms Leader/Teleoperator Arm with Damiao motors.
This teleoperator uses CAN bus communication to read positions from
Damiao motors that are manually moved (torque disabled).
"""
config_class = OpenArmsLeaderConfig
name = "openarms_leader"
def __init__(self, config: OpenArmsLeaderConfig):
super().__init__(config)
self.config = config
norm_mode_body = MotorNormMode.DEGREES # Always use degrees for Damiao motors
# Right arm motors (on port_right)
# Each arm uses the same CAN IDs since they're on separate buses
motors_right = {}
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
motor = Motor(send_id, motor_type_str, norm_mode_body)
motor.recv_id = recv_id
motor.motor_type = getattr(MotorType, motor_type_str.upper().replace("-", "_"))
motors_right[motor_name] = motor
# Left arm motors (on port_left, same IDs as right since separate bus)
motors_left = {}
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
motor = Motor(send_id, motor_type_str, norm_mode_body)
motor.recv_id = recv_id
motor.motor_type = getattr(MotorType, motor_type_str.upper().replace("-", "_"))
motors_left[motor_name] = motor
# Initialize separate Damiao motors buses (one per arm) with CAN FD support
self.bus_right = DamiaoMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration={k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")},
can_interface=self.config.can_interface,
use_can_fd=self.config.use_can_fd,
bitrate=self.config.can_bitrate,
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
)
self.bus_left = DamiaoMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration={k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")},
can_interface=self.config.can_interface,
use_can_fd=self.config.use_can_fd,
bitrate=self.config.can_bitrate,
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
)
# Initialize Pinocchio robot model for dynamics (optional)
self.pin_robot = None
try:
# Load URDF - try external path first (with meshes), then repository
import os
from os.path import expanduser, dirname
# Try external URDF with meshes first
external_urdf_path = expanduser("~/Documents/openarm_description/openarm_bimanual_pybullet.urdf")
if os.path.exists(external_urdf_path):
urdf_path = external_urdf_path
urdf_dir = dirname(urdf_path)
self.pin_robot = pin.RobotWrapper.BuildFromURDF(urdf_path, urdf_dir)
self.pin_robot.data = self.pin_robot.model.createData()
logger.info(f"Loaded OpenArms URDF for dynamics computation from {urdf_path}")
except Exception as e:
logger.warning(f"Could not load URDF for dynamics: {e}. Gravity compensation will not be available.")
@property
def action_features(self) -> Dict[str, type]:
"""Features produced by this teleoperator."""
features = {}
# Right arm motors - only positions stored in dataset
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
# Left arm motors - only positions stored in dataset
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property
def feedback_features(self) -> Dict[str, type]:
"""Feedback features (not implemented for OpenArms)."""
return {}
@property
def is_connected(self) -> bool:
"""Check if teleoperator is connected."""
return self.bus_right.is_connected and self.bus_left.is_connected
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the teleoperator.
For manual control, we disable torque after connecting so the
arm can be moved by hand.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to CAN buses
logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
# Run calibration if needed
if calibrate:
logger.info(
"No calibration found or overwriting calibration. Running calibration..."
)
self.calibrate()
# Configure for manual control
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
"""Check if teleoperator is calibrated."""
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for OpenArms leader.
The calibration procedure:
1. Disable torque (if not already disabled)
2. Ask user to position arm in zero position (hanging with gripper closed)
3. Set this as zero position
4. Record range of motion for each joint
5. Save calibration
"""
if self.calibration:
# Ask user whether to use existing calibration
user_input = input(
f"Press ENTER to use existing calibration for {self.id}, "
f"or type 'c' and press ENTER to run new calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}")
# Split calibration for each bus
cal_right = {k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")}
cal_left = {k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return
logger.info(f"\nRunning calibration for {self}")
# Calibrate each arm separately
self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
def _calibrate_arm(self, arm_name: str, bus: DamiaoMotorsBus) -> None:
"""Calibrate a single arm."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
# Ensure torque is disabled for manual positioning
bus.disable_torque()
time.sleep(0.1)
# Step 1: Set zero position
input(
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
# Set current position as zero for all motors
bus.set_zero_position()
logger.info(f"{arm_name.capitalize()} arm zero position set.")
# Automatically set range to -90° to +90° for all joints
print(
f"\nAutomatically setting range: -90° to +90° for all joints"
)
# Create calibration data with fixed ranges
if self.calibration is None:
self.calibration = {}
for motor_name, motor in bus.motors.items():
# Prefix motor name with arm name for storage
prefixed_name = f"{arm_name}_{motor_name}"
# Use -90 to +90 for all joints and gripper (integers required)
self.calibration[prefixed_name] = MotorCalibration(
id=motor.id,
drive_mode=0, # Normal direction
homing_offset=0, # Already set via set_zero_position
range_min=-90, # -90 degrees (integer)
range_max=90, # +90 degrees (integer)
)
logger.info(f" {prefixed_name}: range set to [-90°, +90°]")
# Write calibration to this arm's motors
cal_for_bus = {k.replace(f"{arm_name}_", ""): v for k, v in self.calibration.items() if k.startswith(f"{arm_name}_")}
bus.write_calibration(cal_for_bus)
# Save calibration after each arm
self._save_calibration()
def configure(self) -> None:
"""
Configure motors for manual teleoperation.
For manual control, we disable torque so the arm can be moved by hand.
"""
if self.config.manual_control:
# Disable torque for manual control
logger.info("Disabling torque for manual control...")
self.bus_right.disable_torque()
self.bus_left.disable_torque()
else:
# Configure motors normally
self.bus_right.configure_motors()
self.bus_left.configure_motors()
def setup_motors(self) -> None:
raise NotImplementedError("Motor ID configuration is typically done via manufacturer tools for CAN motors.")
def get_action(self) -> Dict[str, Any]:
"""
Get current action from the leader arm.
This is the main method for teleoperators - it reads the current state
of the leader arm and returns it as an action that can be sent to a follower.
Reads all motor states (pos/vel/torque) in one CAN refresh cycle.
Note: Velocity and torque are read but not stored in dataset (only used for
gravity/friction compensation during recording).
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
action_dict = {}
start = time.perf_counter()
# OPTIMIZED: Use sync_read_all_states to get pos/vel/torque in one go
right_states = self.bus_right.sync_read_all_states()
for motor in self.bus_right.motors:
state = right_states.get(motor, {})
action_dict[f"right_{motor}.pos"] = state.get("position", 0.0)
action_dict[f"right_{motor}.vel"] = state.get("velocity", 0.0)
action_dict[f"right_{motor}.torque"] = state.get("torque", 0.0)
# OPTIMIZED: Use sync_read_all_states to get pos/vel/torque in one go
left_states = self.bus_left.sync_read_all_states()
for motor in self.bus_left.motors:
state = left_states.get(motor, {})
action_dict[f"left_{motor}.pos"] = state.get("position", 0.0)
action_dict[f"left_{motor}.vel"] = state.get("velocity", 0.0)
action_dict[f"left_{motor}.torque"] = state.get("torque", 0.0)
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
return action_dict
def send_feedback(self, feedback: Dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArms leader.")
def disconnect(self) -> None:
"""Disconnect from teleoperator."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# For manual control, ensure torque is disabled before disconnecting
if self.config.manual_control:
try:
self.bus_right.disable_torque()
self.bus_left.disable_torque()
except Exception as e:
logger.warning(f"Failed to disable torque during disconnect: {e}")
# Disconnect from CAN buses
self.bus_right.disconnect(disable_torque=False) # Already disabled above if needed
self.bus_left.disconnect(disable_torque=False)
logger.info(f"{self} disconnected.")
def _deg_to_rad(self, deg: Dict[str, float | int]) -> Dict[str, float]:
"""Convert degrees to radians for all motors."""
return {m: np.deg2rad(float(v)) for m, v in deg.items()}
def _gravity_from_q(self, q_rad: Dict[str, float]) -> Dict[str, float]:
"""
Compute g(q) [N·m] for all joints in the robot.
The order of joints in the URDF matches the concatenated motor lists (right then left).
Args:
q_rad: Dictionary mapping motor names (with arm prefix) to positions in radians
Returns:
Dictionary mapping motor names to gravity torques in N·m
Raises:
RuntimeError: If URDF model is not loaded
"""
if self.pin_robot is None:
raise RuntimeError(
"Cannot compute gravity: URDF model not loaded. "
"Ensure urdf/openarms.urdf exists and is valid."
)
# Build position vector in the order of motors (left arm, then right arm)
# This order must match the URDF joint order
# URDF has: left_joint1-7, left_finger_joint1-2, right_joint1-7, right_finger_joint1-2
q = np.zeros(self.pin_robot.model.nq)
idx = 0
# Left arm motors (first in URDF) - joints 1-7
for motor_name in self.bus_left.motors:
if motor_name == "gripper":
continue # Skip gripper, will be handled separately
full_name = f"left_{motor_name}"
q[idx] = q_rad.get(full_name, 0.0)
idx += 1
# Skip left finger joints (leave as zeros)
idx += 2
# Right arm motors (second in URDF) - joints 1-7
for motor_name in self.bus_right.motors:
if motor_name == "gripper":
continue # Skip gripper, will be handled separately
full_name = f"right_{motor_name}"
q[idx] = q_rad.get(full_name, 0.0)
idx += 1
# Skip right finger joints (leave as zeros)
idx += 2
# Compute generalized gravity vector
g = pin.computeGeneralizedGravity(self.pin_robot.model, self.pin_robot.data, q)
# Map back to motor names (only arm joints, not fingers)
result = {}
idx = 0
# Left arm torques (joints 1-7)
for motor_name in self.bus_left.motors:
if motor_name == "gripper":
result["left_gripper"] = 0.0 # No gravity compensation for gripper
continue
result[f"left_{motor_name}"] = float(g[idx])
idx += 1
# Skip left finger joint torques in output
idx += 2
# Right arm torques (joints 1-7)
for motor_name in self.bus_right.motors:
if motor_name == "gripper":
result["right_gripper"] = 0.0 # No gravity compensation for gripper
continue
result[f"right_{motor_name}"] = float(g[idx])
idx += 1
# Skip right finger joint torques in output
idx += 2
return result
def _friction_from_velocity(
self,
velocity_rad_per_sec: Dict[str, float],
friction_scale: float = 1.0,
amp_tmp: float = 1.0,
coef_tmp: float = 0.1
) -> Dict[str, float]:
"""
Compute friction torques for all joints in the robot using tanh friction model.
Args:
velocity_rad_per_sec: Dictionary mapping motor names (with arm prefix) to velocities in rad/s
friction_scale: Scale factor for friction compensation (default 1.0, use 0.3 for stability)
amp_tmp: Amplitude factor for tanh term (default 1.0)
coef_tmp: Coefficient for tanh steepness (default 0.1)
Returns:
Dictionary mapping motor names to friction torques in N·m
"""
# Motor name to index mapping
motor_name_to_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
result = {}
# Process all motors (left and right)
for motor_full_name, velocity in velocity_rad_per_sec.items():
# Extract motor name without arm prefix
if motor_full_name.startswith("right_"):
motor_name = motor_full_name.removeprefix("right_")
elif motor_full_name.startswith("left_"):
motor_name = motor_full_name.removeprefix("left_")
else:
result[motor_full_name] = 0.0
continue
# Get motor index for friction parameters
motor_index = motor_name_to_index.get(motor_name, 0)
# Get friction parameters from config
Fc = self.config.friction_fc[motor_index]
k = self.config.friction_k[motor_index]
Fv = self.config.friction_fv[motor_index]
Fo = self.config.friction_fo[motor_index]
# Friction model: τ_fric = amp * Fc * tanh(coef * k * ω) + Fv * ω + Fo
friction_torque = (
amp_tmp * Fc * np.tanh(coef_tmp * k * velocity) +
Fv * velocity +
Fo
)
# Apply scale factor
friction_torque *= friction_scale
result[motor_full_name] = float(friction_torque)
return result
def get_damping_kd(self, motor_name: str) -> float:
"""
Get damping gain (Kd) for a specific motor.
Args:
motor_name: Motor name without arm prefix (e.g., "joint_1", "gripper")
Returns:
Damping gain value
"""
motor_name_to_index = {
"joint_1": 0,
"joint_2": 1,
"joint_3": 2,
"joint_4": 3,
"joint_5": 4,
"joint_6": 5,
"joint_7": 6,
"gripper": 7,
}
motor_index = motor_name_to_index.get(motor_name, 0)
return self.config.damping_kd[motor_index]
@@ -1,33 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..teleoperator import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("openarms_mini")
@dataclass
class OpenArmsMiniConfig(TeleoperatorConfig):
"""Configuration for OpenArms Mini teleoperator with Feetech motors (dual arms)."""
# Serial ports for left and right arms
port_right: str = "/dev/ttyUSB0" # Serial port for right arm
port_left: str = "/dev/ttyUSB1" # Serial port for left arm
# Whether to use degrees mode (True) or normalized mode (False)
use_degrees: bool = True

Some files were not shown because too many files have changed in this diff Show More