diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml
index d16fe5e72..0155eec13 100644
--- a/.github/workflows/full_tests.yml
+++ b/.github/workflows/full_tests.yml
@@ -78,7 +78,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install lerobot with all extras
- run: uv sync --all-extras
+ run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10
diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index f9fa02597..be8b5c094 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -189,5 +189,6 @@ jobs:
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
- run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
+ # TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
+ run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
timeout-minutes: 10
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 67aa5186b..4891707ac 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -82,6 +82,14 @@ jobs:
exit 1
fi
+ - name: Remove Tags with Git dependencies
+        # TODO(Steven): Temporary patch to remove the pi extra from the PyPI 0.4.0 release due to its reliance on git dependencies.
+ run: |
+ echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
+ grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
+ sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
+ echo "::info:: Git dependencies removed. Proceeding with build."
+
- name: Install build dependencies
run: python -m pip install build
@@ -103,7 +111,7 @@ jobs:
- name: Publish to TestPyPI for pre-releases
# True for tags like 'v0.2.0-rc1'
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
- uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
+ uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with:
repository-url: https://test.pypi.org/legacy/
verbose: true
@@ -111,7 +119,7 @@ jobs:
- name: Publish to PyPI
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
- uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
+ uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with:
verbose: true
print-hash: true
@@ -138,7 +146,7 @@ jobs:
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with:
- enable-cache: true
+ enable-cache: true # zizmor: ignore[cache-poisoning]
version: ${{ env.UV_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
- name: Create uv virtual environment
diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml
index 902074a83..b6fc5ea2f 100644
--- a/.github/workflows/unbound_deps_tests.yml
+++ b/.github/workflows/unbound_deps_tests.yml
@@ -70,7 +70,7 @@ jobs:
echo "Dependencies unbound:" && cat pyproject.toml
- name: Install lerobot with all extras
- run: uv sync --all-extras
+ run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
- name: Run pytest (all extras)
run: uv run pytest tests -vv
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7f5beff80..896a0c10b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,7 +26,7 @@ repos:
##### General Code Quality & Formatting #####
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v5.0.0
+ rev: v6.0.0
hooks:
- id: check-added-large-files
args: ['--maxkb=1024']
@@ -39,20 +39,20 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.12.4
+ rev: v0.14.1
hooks:
- id: ruff-format
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/adhtruong/mirrors-typos
- rev: v1.34.0
+ rev: v1.38.1
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
- rev: v3.20.0
+ rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]
@@ -68,12 +68,12 @@ repos:
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
- rev: v8.27.2
+ rev: v8.28.0
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
- rev: v1.11.0
+ rev: v1.15.2
hooks:
- id: zizmor
@@ -87,7 +87,7 @@ repos:
# TODO(Steven): Uncomment when ready to use
##### Static Analysis & Typing #####
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.16.0
+ rev: v1.18.2
hooks:
- id: mypy
args: [--config-file=pyproject.toml]
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a07596728..dcb5c03d8 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -137,7 +137,7 @@ Follow these steps to start contributing:
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
- Set up a development environment with conda or miniconda:
+ Set up a development environment with conda:
```bash
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
diff --git a/README.md b/README.md
index 6409d931d..964af4c1d 100644
--- a/README.md
+++ b/README.md
@@ -104,14 +104,14 @@ LeRobot works with Python 3.10+ and PyTorch 2.2+.
### Environment Setup
-Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
+Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```
-When using `miniconda`, install `ffmpeg` in your environment:
+When using `conda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
@@ -185,6 +185,11 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
+> [!NOTE]
+> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
+>
+> This will be solved in the next patch release
+
### Weights & Biases
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
@@ -207,13 +212,13 @@ lerobot-dataset-viz \
--episode-index 0
```
-or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
+or from a dataset in a local folder with the `--root` option and `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
- --local-files-only 1 \
+ --mode local \
--episode-index 0
```
@@ -337,7 +342,3 @@ If you want, you can cite this work with:
## Star History
[](https://star-history.com/#huggingface/lerobot&Timeline)
-
-```
-
-```
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 5e100013a..765f6698d 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -15,8 +15,6 @@
title: Train a Robot with RL
- local: hilserl_sim
title: Train RL in Simulation
- - local: async
- title: Use Async Inference
- local: multi_gpu_training
title: Multi GPU training
title: "Tutorials"
@@ -37,8 +35,18 @@
title: π₀ (Pi0)
- local: pi05
title: π₀.₅ (Pi05)
+ - local: groot
+ title: NVIDIA GR00T N1.5
title: "Policies"
- sections:
+ - local: async
+ title: Use Async Inference
+ - local: rtc
+ title: Real-Time Chunking (RTC)
+ title: "Inference"
+- sections:
+ - local: envhub
+ title: Environments from the Hub
- local: il_sim
title: Imitation Learning in Sim
- local: libero
@@ -55,6 +63,8 @@
title: Implement your own processor
- local: processors_robots_teleop
title: Processors for Robots and Teleoperators
+ - local: env_processor
+ title: Environment Processors
title: "Robot Processors"
- sections:
- local: so101
diff --git a/docs/source/env_processor.mdx b/docs/source/env_processor.mdx
new file mode 100644
index 000000000..8dbf315c7
--- /dev/null
+++ b/docs/source/env_processor.mdx
@@ -0,0 +1,418 @@
+# Environment Processors
+
+Environment processors are a critical layer in LeRobot's data processing architecture that handle **environment-specific** transformations, separate from policy-specific processing. This separation of concerns enables cleaner code, better modularity, and easier experimentation with different environments and policies.
+
+## Why Environment Processors?
+
+When working with different robot environments (LIBERO, MetaWorld, Aloha, etc.), each environment often has unique data formats, coordinate systems, and conventions that need standardization **before** policy processing. Without environment processors, these transformations would be:
+
+1. **Hardcoded in environment code** - Making it difficult to experiment with different state representations
+2. **Duplicated across policies** - Each policy would need to handle environment-specific quirks
+3. **Mixed with policy logic** - Violating separation of concerns and making debugging harder
+
+Environment processors solve this by providing a **dedicated processing layer** between raw environment observations and policy inputs.
+
+## The Processing Pipeline
+
+Here's how data flows through the complete processing pipeline during evaluation:
+
+```python
+# In lerobot_eval.py rollout() function:
+
+# 1. Raw environment observation (numpy arrays, various formats)
+raw_observation, reward, terminated, truncated, info = env.step(action)
+
+# 2. Convert numpy to torch, normalize images [0,1]
+observation = preprocess_observation(raw_observation)
+
+# 3. Add task metadata (for multi-task environments)
+observation = add_envs_task(env, observation)
+
+# 4. ENVIRONMENT-SPECIFIC preprocessing (NEW!)
+# - Flatten robot states
+# - Rotate images to match dataset conventions
+# - Handle environment-specific coordinate systems
+observation = env_preprocessor(observation)
+
+# 5. POLICY-SPECIFIC preprocessing
+# - Normalize with dataset statistics
+# - Add batch dimensions
+# - Move to GPU
+# - Tokenize language instructions
+observation = preprocessor(observation)
+
+# 6. Policy inference
+action = policy.select_action(observation)
+
+# 7. POLICY-SPECIFIC postprocessing
+# - Unnormalize actions
+# - Remove batch dimensions
+action = postprocessor(action)
+
+# 8. ENVIRONMENT-SPECIFIC postprocessing (NEW!)
+# - Convert action formats if needed
+# - Apply environment-specific constraints
+action_transition = {"action": action}
+action_transition = env_postprocessor(action_transition)
+action = action_transition["action"]
+
+# 9. Execute in environment
+env.step(action)
+```
+
+## The Benefits
+
+### 1. **Separation of Concerns**
+
+Environment processors handle transformations specific to the **environment's data format**, while policy processors handle transformations specific to the **model's requirements**.
+
+```python
+# ❌ Before: Mixed concerns
+class LiberoVLAPolicy:
+ def preprocess(self, obs):
+ # Environment-specific: Flatten robot state (shouldn't be in policy!)
+ state = self._flatten_robot_state(obs["robot_state"])
+ # Policy-specific: Normalize with dataset stats
+ state = self.normalizer(state)
+ return state
+
+# ✅ After: Clear separation
+# Environment processor: Handles LIBERO's nested robot state
+env_preprocessor = LiberoProcessorStep() # Flattens robot_state
+
+# Policy processor: Handles model requirements
+policy_preprocessor = NormalizerProcessorStep(stats=dataset_stats)
+```
+
+### 2. **Flexibility and Reusability**
+
+The same policy can work with different environment processors, and the same environment processor can work with different policies:
+
+```python
+# Use SmolVLA policy with LIBERO environment
+libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
+smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
+
+# Or use ACT policy with the same LIBERO environment
+libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
+act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
+```
+
+### 3. **Easier Experimentation**
+
+Want to try different state representations for LIBERO? Just create a new processor:
+
+```python
+# Original: 8D state (pos + quat→axisangle + gripper)
+@ProcessorStepRegistry.register("libero_processor")
+class LiberoProcessorStep(ObservationProcessorStep):
+    def _process_observation(self, obs):
+        robot_state = obs["observation.robot_state"]
+        eef_pos = robot_state["eef"]["pos"]  # 3D
+        eef_axisangle = quat2axisangle(robot_state["eef"]["quat"])  # 3D
+        gripper = robot_state["gripper"]["qpos"]  # 2D
+        state = torch.cat([eef_pos, eef_axisangle, gripper], dim=-1)  # 8D
+        return state
+
+# Experiment: Add velocity for better control
+@ProcessorStepRegistry.register("libero_velocity_processor")
+class LiberoVelocityProcessorStep(ObservationProcessorStep):
+    def _process_observation(self, obs):
+        # Include velocities for a 14D state
+        robot_state = obs["observation.robot_state"]
+        eef_pos = robot_state["eef"]["pos"]  # 3D
+        eef_axisangle = quat2axisangle(robot_state["eef"]["quat"])  # 3D
+        eef_vel = robot_state["eef"]["vel"]  # 3D (NEW)
+        gripper_pos = robot_state["gripper"]["qpos"]  # 2D
+        gripper_vel = robot_state["gripper"]["qvel"]  # 3D (NEW)
+        state = torch.cat([eef_pos, eef_axisangle, eef_vel,
+                           gripper_pos, gripper_vel], dim=-1)  # 14D
+        return state
+```
+
+### 4. **Cleaner Environment Code**
+
+Environments expose **all available data** without needing to know what downstream models will use:
+
+```python
+# LIBERO environment exposes full robot state
+observation = {
+ "pixels": {"image": img, "image2": img2},
+ "robot_state": {
+ "eef": {"pos": ..., "quat": ..., "vel": ..., "mat": ..., "axisangle": ...},
+ "gripper": {"qpos": ..., "qvel": ...},
+ "joints": {"pos": ..., "vel": ...}
+ }
+}
+
+# Environment processor decides what to use
+# Policy processor handles model-specific transformations
+```
+
+## Using Environment Processors
+
+### Factory Function
+
+The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
+
+```python
+from lerobot.envs.factory import make_env_pre_post_processors
+from lerobot.envs.configs import LiberoEnv, PushtEnv
+
+# For LIBERO: Returns LiberoProcessorStep in preprocessor
+libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
+env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
+
+# For other environments: Returns identity processors (no-op)
+pusht_cfg = PushtEnv()
+env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
+```
+
+### Implementation in `envs/factory.py`
+
+```python
+def make_env_pre_post_processors(
+ env_cfg: EnvConfig,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+]:
+ """
+ Create preprocessor and postprocessor pipelines for environment observations.
+
+ Args:
+ env_cfg: The configuration of the environment.
+
+ Returns:
+ A tuple containing:
+ - preprocessor: Pipeline that processes environment observations
+ - postprocessor: Pipeline that processes environment outputs
+ """
+ # For LIBERO environments, add the LiberoProcessorStep to preprocessor
+ if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+ preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
+ else:
+ # For all other environments, return an identity preprocessor
+ preprocessor = PolicyProcessorPipeline(steps=[])
+
+ # Postprocessor is currently identity for all environments
+ # Future: Could add environment-specific action transformations
+ postprocessor = PolicyProcessorPipeline(steps=[])
+
+ return preprocessor, postprocessor
+```
+
+### Integration in Evaluation
+
+In `lerobot_eval.py`, the environment processors are created once and used throughout:
+
+```python
+def eval_main(cfg: EvalPipelineConfig):
+ # Create environment
+ envs = make_env(cfg.env, n_envs=cfg.eval.batch_size)
+
+ # Create policy
+ policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
+
+ # Create policy processors
+ preprocessor, postprocessor = make_pre_post_processors(
+ policy_cfg=cfg.policy,
+ pretrained_path=cfg.policy.pretrained_path,
+ )
+
+ # Create environment processors (NEW!)
+ env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
+
+ # Run evaluation with both processor types
+ eval_policy_all(
+ envs=envs,
+ policy=policy,
+ env_preprocessor=env_preprocessor, # Environment-specific
+ env_postprocessor=env_postprocessor, # Environment-specific
+ preprocessor=preprocessor, # Policy-specific
+ postprocessor=postprocessor, # Policy-specific
+ n_episodes=cfg.eval.n_episodes,
+ )
+```
+
+## Example: LIBERO Environment Processor
+
+The `LiberoProcessorStep` demonstrates a real-world environment processor:
+
+```python
+from lerobot.processor.pipeline import ObservationProcessorStep
+
+@dataclass
+@ProcessorStepRegistry.register(name="libero_processor")
+class LiberoProcessorStep(ObservationProcessorStep):
+ """
+ Processes LIBERO observations into the LeRobot format.
+
+ **State Processing:**
+ - Extracts end-effector position (3D)
+ - Converts quaternion to axis-angle representation (3D)
+ - Extracts gripper joint positions (2D)
+ - Concatenates into 8D state vector
+
+ **Image Processing:**
+ - Rotates images 180° to match HuggingFaceVLA/libero convention
+ """
+
+ def _process_observation(self, observation):
+ processed_obs = observation.copy()
+
+ # Process images: Flip 180° for camera convention
+ for key in list(processed_obs.keys()):
+ if key.startswith("observation.images."):
+ img = processed_obs[key]
+ img = torch.flip(img, dims=[2, 3]) # Flip H and W
+ processed_obs[key] = img
+
+ # Process robot_state: Flatten to 8D vector
+ if "observation.robot_state" in processed_obs:
+ robot_state = processed_obs.pop("observation.robot_state")
+
+ eef_pos = robot_state["eef"]["pos"] # (B, 3)
+ eef_quat = robot_state["eef"]["quat"] # (B, 4)
+ gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2)
+
+ # Convert quaternion to axis-angle
+ eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
+
+ # Concatenate into single state vector
+ state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
+ state = state.float()
+
+ processed_obs["observation.state"] = state
+
+ return processed_obs
+```
+
+### Why These Transformations?
+
+1. **Image Rotation**: The HuggingFaceVLA/libero dataset has images rotated 180° from the raw LIBERO simulator. The processor handles this convention mismatch so policies trained on the dataset work seamlessly.
+
+2. **State Flattening**: The raw LIBERO environment exposes nested dictionaries with all available state information (position, quaternion, velocity, matrix representation, etc.). The processor:
+ - Selects the relevant components (pos, quat, gripper)
+ - Converts quaternion to axis-angle (more suitable for learning)
+ - Flattens to a single 8D vector that policies expect
+
+3. **Flexibility**: The environment still exposes **all** raw data. If you want to try different state representations (e.g., including velocities, using matrix representation instead of axis-angle), you can create a new processor without modifying the environment code.
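+
+For the quaternion-to-axis-angle conversion mentioned in point 2, here is a minimal sketch of what such a helper could look like, assuming LIBERO's `(x, y, z, w)` quaternion ordering and batched torch tensors (the actual helper lives on the processor step as `_quat2axisangle`):
+
+```python
+import torch
+
+
+def quat2axisangle(quat: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """Convert batched quaternions (B, 4) in (x, y, z, w) order to axis-angle (B, 3)."""
+    xyz, w = quat[..., :3], quat[..., 3].clamp(-1.0, 1.0)
+    angle = 2.0 * torch.acos(w)  # rotation angle
+    sin_half = torch.sqrt((1.0 - w * w).clamp_min(0.0))
+    # Near-zero rotations: the axis is ill-defined, so return zeros instead of dividing by ~0
+    scale = torch.where(sin_half > eps, angle / sin_half.clamp_min(eps), torch.zeros_like(angle))
+    return xyz * scale.unsqueeze(-1)
+```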
+
+## Adding Environment Processors for New Environments
+
+To add environment processors for a new environment:
+
+### 1. Create the Processor Step
+
+```python
+# In src/lerobot/processor/env_processor.py
+
+@dataclass
+@ProcessorStepRegistry.register(name="myenv_processor")
+class MyEnvProcessorStep(ObservationProcessorStep):
+ """Process observations from MyEnv."""
+
+ def _process_observation(self, observation):
+ processed = observation.copy()
+
+ # Your environment-specific transformations
+ if "myenv.specific.state" in processed:
+ state = processed.pop("myenv.specific.state")
+ # Transform to standard format
+ processed["observation.state"] = self._transform_state(state)
+
+ return processed
+```
+
+### 2. Update the Factory
+
+```python
+# In src/lerobot/envs/factory.py
+
+def make_env_pre_post_processors(env_cfg: EnvConfig):
+ if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+ preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
+ elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
+ preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
+ else:
+ preprocessor = PolicyProcessorPipeline(steps=[])
+
+ postprocessor = PolicyProcessorPipeline(steps=[])
+ return preprocessor, postprocessor
+```
+
+### 3. Use in Evaluation
+
+No changes needed! Passing `--env.type=myenv` makes the evaluation script pick up `MyEnvProcessorStep` automatically:
+
+```bash
+lerobot-eval \
+ --policy.path=lerobot/my_policy \
+    --env.type=myenv \
+ --eval.n_episodes=10
+```
+
+## Future: Environment Postprocessors
+
+Currently, postprocessors are identity (no-op) for all environments. Future use cases include:
+
+### Action Space Transformations
+
+```python
+@dataclass
+class MyEnvActionPostprocessor(ProcessorStep):
+ """Convert policy actions to environment-specific format."""
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ action = transition["action"]
+
+ # Example: Convert from Cartesian to joint space
+ if self.action_space == "joint":
+ action = self.ik_solver(action)
+
+ # Example: Apply environment-specific safety limits
+ action = torch.clamp(action, self.min_action, self.max_action)
+
+ transition["action"] = action
+ return transition
+```
+
+### Coordinate System Conversions
+
+```python
+@dataclass
+class CoordinateTransformPostprocessor(ProcessorStep):
+ """Transform actions between coordinate systems."""
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ action = transition["action"]
+
+ # Example: Policy outputs in world frame, env expects base frame
+ action = self.world_to_base_transform(action)
+
+ transition["action"] = action
+ return transition
+```
+
+## Best Practices
+
+1. **Keep environment processors simple**: They should only handle environment-specific data format issues, not complex learning-related transformations.
+
+2. **Use policy processors for model requirements**: Normalization, batching, device placement, and tokenization belong in policy processors.
+
+3. **Expose all data from environments**: Let processors decide what to use rather than hardcoding choices in the environment.
+
+4. **Document conventions**: Clearly document any coordinate system conventions, camera orientations, or data formats that your processor handles.
+
+5. **Test independently**: Environment processors should be testable without loading full policies or environments.
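+
+As a quick illustration of point 5, here is a minimal sketch of a standalone test for the LIBERO step, assuming `LiberoProcessorStep` can be imported from the processor module (adjust the import to wherever the step actually lives):
+
+```python
+import math
+
+import torch
+
+from lerobot.processor.env_processor import LiberoProcessorStep  # assumed module path
+
+
+def test_libero_processor_flattens_state():
+    step = LiberoProcessorStep()
+    batch = 2
+    half = math.sqrt(0.5)  # 90° rotation about z, in (x, y, z, w) order
+    obs = {
+        "observation.images.image": torch.rand(batch, 3, 224, 224),
+        "observation.robot_state": {
+            "eef": {"pos": torch.zeros(batch, 3), "quat": torch.tensor([[0.0, 0.0, half, half]] * batch)},
+            "gripper": {"qpos": torch.zeros(batch, 2)},
+        },
+    }
+    out = step._process_observation(obs)
+    # The nested robot state is replaced by a flat 8D vector (pos + axis-angle + gripper)
+    assert "observation.robot_state" not in out
+    assert out["observation.state"].shape == (batch, 8)
+```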
+
+## Summary
+
+Environment processors provide a **clean separation** between environment-specific data transformations and policy-specific model requirements. This architecture:
+
+- ✅ Enables easy experimentation with different state representations
+- ✅ Allows policies to work seamlessly across different environments
+- ✅ Keeps environment code focused on simulation/hardware interface
+- ✅ Makes processor pipelines more maintainable and debuggable
+- ✅ Follows the single responsibility principle
+
+The key insight: **Environments define data formats, processors standardize them, policies consume standardized data.** Each layer has a clear, focused responsibility.
diff --git a/docs/source/envhub.mdx b/docs/source/envhub.mdx
new file mode 100644
index 000000000..ba6464460
--- /dev/null
+++ b/docs/source/envhub.mdx
@@ -0,0 +1,424 @@
+# Loading Environments from the Hub
+
+The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a powerful new model for collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community.
+
+## Overview
+
+With EnvHub, you can:
+
+- Load environments from the Hub instantly
+- Share your custom simulation tasks with the community
+- Version control your environments using Git
+- Distribute complex physics simulations without packaging hassles
+
+## Quick Start
+
+Loading an environment from the Hub is as simple as:
+
+```python
+from lerobot.envs.factory import make_env
+
+# Load a hub environment (requires explicit consent to run remote code)
+env = make_env("lerobot/cartpole-env", trust_remote_code=True)
+```
+
+
+ **Security Notice**: Loading environments from the Hub executes Python code
+ from third-party repositories. Only use `trust_remote_code=True` with
+ repositories you trust. We strongly recommend pinning to a specific commit
+ hash for reproducibility and security.
+
+
+## What is EnvHub?
+
+EnvHub is a framework that allows researchers and developers to:
+
+1. **Publish environments** to the Hugging Face Hub as Git repositories
+2. **Load environments** dynamically without installing them as packages
+3. **Version and track** environment changes using Git semantics
+4. **Discover** new simulation tasks shared by the community
+
+This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, without worrying about dependency conflicts or complex installation procedures.
+
+## Repository Structure
+
+To make your environment loadable from the Hub, your repository must contain at minimum:
+
+### Required Files
+
+**`env.py`** (or custom Python file)
+
+- Must expose a `make_env(n_envs: int, use_async_envs: bool)` function
+- This function should return one of:
+ - A `gym.vector.VectorEnv` (most common)
+ - A single `gym.Env` (will be automatically wrapped)
+ - A dict mapping `{suite_name: {task_id: VectorEnv}}` (for multi-task benchmarks)
+
+### Optional Files
+
+**`requirements.txt`**
+
+- List any additional dependencies your environment needs
+- Users will need to install these manually before loading your environment
+
+**`README.md`**
+
+- Document your environment: what task it implements, observation/action spaces, rewards, etc.
+- Include usage examples and any special setup instructions
+
+**`.gitignore`**
+
+- Exclude unnecessary files from your repository
+
+### Example Repository Structure
+
+```
+my-environment-repo/
+├── env.py # Main environment definition (required)
+├── requirements.txt # Dependencies (optional)
+├── README.md # Documentation (recommended)
+├── assets/ # Images, videos, etc. (optional)
+│ └── demo.gif
+└── configs/ # Config files if needed (optional)
+ └── task_config.yaml
+```
+
+## Creating Your Environment Repository
+
+### Step 1: Define Your Environment
+
+Create an `env.py` file with a `make_env` function:
+
+```python
+# env.py
+import gymnasium as gym
+
+def make_env(n_envs: int = 1, use_async_envs: bool = False):
+ """
+ Create vectorized environments for your custom task.
+
+ Args:
+ n_envs: Number of parallel environments
+ use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv
+
+ Returns:
+ gym.vector.VectorEnv or dict mapping suite names to vectorized envs
+ """
+ def _make_single_env():
+ # Create your custom environment
+ return gym.make("CartPole-v1")
+
+ # Choose vector environment type
+ env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
+
+ # Create vectorized environment
+ vec_env = env_cls([_make_single_env for _ in range(n_envs)])
+
+ return vec_env
+```
+
+### Step 2: Test Locally
+
+Before uploading, test your environment locally:
+
+```python
+from lerobot.envs.utils import _load_module_from_path, _call_make_env, _normalize_hub_result
+
+# Load your module
+module = _load_module_from_path("./env.py")
+
+# Test the make_env function
+result = _call_make_env(module, n_envs=2, use_async_envs=False)
+normalized = _normalize_hub_result(result)
+
+# Verify it works
+suite_name = next(iter(normalized))
+env = normalized[suite_name][0]
+obs, info = env.reset()
+print(f"Observation shape: {obs.shape if hasattr(obs, 'shape') else type(obs)}")
+env.close()
+```
+
+### Step 3: Upload to the Hub
+
+Upload your repository to Hugging Face:
+
+```bash
+# Install huggingface_hub if needed
+pip install huggingface_hub
+
+# Login to Hugging Face
+huggingface-cli login
+
+# Create a new repository
+huggingface-cli repo create my-custom-env --type space --org my-org
+
+# Initialize git and push
+git init
+git add .
+git commit -m "Initial environment implementation"
+git remote add origin https://huggingface.co/my-org/my-custom-env
+git push -u origin main
+```
+
+Alternatively, use the `huggingface_hub` Python API:
+
+```python
+from huggingface_hub import HfApi
+
+api = HfApi()
+
+# Create repository
+api.create_repo("my-custom-env", repo_type="space")
+
+# Upload files
+api.upload_folder(
+ folder_path="./my-env-folder",
+ repo_id="username/my-custom-env",
+ repo_type="space",
+)
+```
+
+## Loading Environments from the Hub
+
+### Basic Usage
+
+```python
+from lerobot.envs.factory import make_env
+
+# Load from the hub
+envs_dict = make_env(
+ "username/my-custom-env",
+ n_envs=4,
+ trust_remote_code=True
+)
+
+# Access the environment
+suite_name = next(iter(envs_dict))
+env = envs_dict[suite_name][0]
+
+# Use it like any gym environment
+obs, info = env.reset()
+action = env.action_space.sample()
+obs, reward, terminated, truncated, info = env.step(action)
+```
+
+### Advanced: Pinning to Specific Versions
+
+For reproducibility and security, pin to a specific Git revision:
+
+```python
+# Pin to a specific branch
+env = make_env("username/my-env@main", trust_remote_code=True)
+
+# Pin to a specific commit (recommended for papers/experiments)
+env = make_env("username/my-env@abc123def456", trust_remote_code=True)
+
+# Pin to a tag
+env = make_env("username/my-env@v1.0.0", trust_remote_code=True)
+```
+
+### Custom File Paths
+
+If your environment definition is not in `env.py`:
+
+```python
+# Load from a custom file
+env = make_env("username/my-env:custom_env.py", trust_remote_code=True)
+
+# Combine with version pinning
+env = make_env("username/my-env@v1.0:envs/task_a.py", trust_remote_code=True)
+```
+
+### Async Environments
+
+For better performance with multiple environments:
+
+```python
+envs_dict = make_env(
+ "username/my-env",
+ n_envs=8,
+ use_async_envs=True, # Use AsyncVectorEnv for parallel execution
+ trust_remote_code=True
+)
+```
+
+## URL Format Reference
+
+The hub URL format supports several patterns:
+
+| Pattern | Description | Example |
+| -------------------- | ------------------------------ | -------------------------------------- |
+| `user/repo` | Load `env.py` from main branch | `make_env("lerobot/pusht-env")` |
+| `user/repo@revision` | Load from specific revision | `make_env("lerobot/pusht-env@main")` |
+| `user/repo:path` | Load custom file | `make_env("lerobot/envs:pusht.py")` |
+| `user/repo@rev:path` | Revision + custom file | `make_env("lerobot/envs@v1:pusht.py")` |
+
+## Multi-Task Environments
+
+For benchmarks with multiple tasks (like LIBERO), return a nested dictionary:
+
+```python
+def make_env(n_envs: int = 1, use_async_envs: bool = False):
+ env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
+
+ # Return dict: {suite_name: {task_id: VectorEnv}}
+ return {
+ "suite_1": {
+ 0: env_cls([lambda: gym.make("Task1-v0") for _ in range(n_envs)]),
+ 1: env_cls([lambda: gym.make("Task2-v0") for _ in range(n_envs)]),
+ },
+ "suite_2": {
+ 0: env_cls([lambda: gym.make("Task3-v0") for _ in range(n_envs)]),
+ }
+ }
+```
+
+## Security Considerations
+
+
+ **Important**: The `trust_remote_code=True` flag is required to execute
+ environment code from the Hub. This is by design for security.
+
+
+When loading environments from the Hub:
+
+1. **Review the code first**: Visit the repository and inspect `env.py` before loading
+2. **Pin to commits**: Use specific commit hashes for reproducibility
+3. **Check dependencies**: Review `requirements.txt` for suspicious packages
+4. **Use trusted sources**: Prefer official organizations or well-known researchers
+5. **Sandbox if needed**: Run untrusted code in isolated environments (containers, VMs)
+
+Example of safe usage:
+
+```python
+# ❌ BAD: Loading without inspection
+env = make_env("random-user/untrusted-env", trust_remote_code=True)
+
+# ✅ GOOD: Review code, then pin to specific commit
+# 1. Visit https://huggingface.co/trusted-org/verified-env
+# 2. Review the env.py file
+# 3. Copy the commit hash
+env = make_env("trusted-org/verified-env@a1b2c3d4", trust_remote_code=True)
+```
+
+## Example: CartPole from the Hub
+
+Here's a complete example using the reference CartPole environment:
+
+```python
+from lerobot.envs.factory import make_env
+import numpy as np
+
+# Load the environment
+envs_dict = make_env("lerobot/cartpole-env", n_envs=4, trust_remote_code=True)
+
+# Get the vectorized environment
+suite_name = next(iter(envs_dict))
+env = envs_dict[suite_name][0]
+
+# Run a simple episode
+obs, info = env.reset()
+done = np.zeros(env.num_envs, dtype=bool)
+total_reward = np.zeros(env.num_envs)
+
+while not done.all():
+ # Random policy
+ action = env.action_space.sample()
+ obs, reward, terminated, truncated, info = env.step(action)
+ total_reward += reward
+    done |= terminated | truncated  # keep finished envs marked done across autoresets
+
+print(f"Average reward: {total_reward.mean():.2f}")
+env.close()
+```
+
+## Benefits of EnvHub
+
+### For Environment Authors
+
+- **Easy distribution**: No PyPI packaging required
+- **Version control**: Use Git for environment versioning
+- **Rapid iteration**: Push updates instantly
+- **Documentation**: Hub README renders beautifully
+- **Community**: Reach LeRobot users directly
+
+### For Researchers
+
+- **Quick experiments**: Load any environment in one line
+- **Reproducibility**: Pin to specific commits
+- **Discovery**: Browse environments on the Hub
+- **No conflicts**: No need to install conflicting packages
+
+### For the Community
+
+- **Growing ecosystem**: More diverse simulation tasks
+- **Standardization**: Common `make_env` API
+- **Collaboration**: Fork and improve existing environments
+- **Accessibility**: Lower barrier to sharing research
+
+## Troubleshooting
+
+### "Refusing to execute remote code"
+
+You must explicitly pass `trust_remote_code=True`:
+
+```python
+env = make_env("user/repo", trust_remote_code=True)
+```
+
+### "Module X not found"
+
+The hub environment has dependencies you need to install:
+
+```bash
+# Check the repo's requirements.txt and install dependencies
+pip install gymnasium numpy
+```
+
+### "make_env not found in module"
+
+Your `env.py` must expose a `make_env` function:
+
+```python
+def make_env(n_envs: int, use_async_envs: bool):
+ # Your implementation
+ pass
+```
+
+### Environment returns wrong type
+
+The `make_env` function must return:
+
+- A `gym.vector.VectorEnv`, or
+- A single `gym.Env`, or
+- A dict `{suite_name: {task_id: VectorEnv}}`
+
+## Best Practices
+
+1. **Document your environment**: Include observation/action space descriptions, reward structure, and termination conditions in your README
+2. **Add requirements.txt**: List all dependencies with versions
+3. **Test thoroughly**: Verify your environment works locally before pushing
+4. **Use semantic versioning**: Tag releases with version numbers
+5. **Add examples**: Include usage examples in your README
+6. **Keep it simple**: Minimize dependencies when possible
+7. **License your work**: Add a LICENSE file to clarify usage terms
+
+## Future Directions
+
+The EnvHub ecosystem enables exciting possibilities:
+
+- **GPU-accelerated physics**: Share Isaac Gym or Brax environments
+- **Photorealistic rendering**: Distribute environments with advanced graphics
+- **Multi-agent scenarios**: Complex interaction tasks
+- **Real-world simulators**: Digital twins of physical setups
+- **Procedural generation**: Infinite task variations
+- **Domain randomization**: Pre-configured DR pipelines
+
+As more researchers and developers contribute, the diversity and quality of available environments will grow, benefiting the entire robotics learning community.
+
+## See Also
+
+- [Hugging Face Hub Documentation](https://huggingface.co/docs/hub/en/index)
+- [Gymnasium Documentation](https://gymnasium.farama.org/index.html)
+- [Example Hub Environment](https://huggingface.co/lerobot/cartpole-env)
diff --git a/docs/source/groot.mdx b/docs/source/groot.mdx
new file mode 100644
index 000000000..729a64656
--- /dev/null
+++ b/docs/source/groot.mdx
@@ -0,0 +1,125 @@
+# GR00T N1.5 Policy
+
+GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
+
+This document outlines the specifics of its integration and usage within the LeRobot framework.
+
+## Model Overview
+
+NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
+
+Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
+
+GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
+
+Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
+
+- Real captured data from robots.
+- Synthetic data generated using NVIDIA Isaac GR00T Blueprint.
+- Internet-scale video data.
+
+This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments.
+
+## Installation Requirements
+
+As of today, GR00T N1.5 requires Flash Attention for its internal computations.
+
+We are working on making this optional, but in the meantime it requires an extra installation step and can only be used on CUDA-enabled devices.
+
+1. Follow the Environment Setup section of our [Installation Guide](./installation). **Attention:** do not install `lerobot` in this step.
+2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
+
+```bash
+# Check https://pytorch.org/get-started/locally/ for your system
+pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
+pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
+pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
+python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
+```
+
+3. Install LeRobot by running:
+
+```bash
+pip install "lerobot[groot]"
+```
+
+## Usage
+
+To use GR00T in your LeRobot configuration, specify the policy type as:
+
+```python
+policy.type=groot
+```
+
+## Training
+
+### Training Command Example
+
+Here's a complete training command for finetuning the base GR00T model on your own dataset:
+
+```bash
+# Using a multi-GPU setup
+accelerate launch \
+ --multi_gpu \
+ --num_processes=$NUM_GPUS \
+ $(which lerobot-train) \
+ --output_dir=$OUTPUT_DIR \
+ --save_checkpoint=true \
+ --batch_size=$BATCH_SIZE \
+ --steps=$NUM_STEPS \
+ --save_freq=$SAVE_FREQ \
+ --log_freq=$LOG_FREQ \
+ --policy.push_to_hub=true \
+ --policy.type=groot \
+ --policy.repo_id=$REPO_ID \
+ --policy.tune_diffusion_model=false \
+ --dataset.repo_id=$DATASET_ID \
+ --wandb.enable=true \
+ --wandb.disable_artifact=true \
+ --job_name=$JOB_NAME
+```
+
+## Performance Results
+
+### Libero Benchmark Results
+
+> [!NOTE]
+> Follow our instructions for Libero usage: [Libero](./libero)
+
+GR00T has demonstrated strong performance on the Libero benchmark suite. To validate the LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared it against the GR00T reference results.
+
+| Benchmark | LeRobot Implementation | GR00T Reference |
+| ------------------ | ---------------------- | --------------- |
+| **Libero Spatial** | 82.0% | 92.0% |
+| **Libero Object** | 99.0% | 92.0% |
+| **Libero Long** | 82.0% | 76.0% |
+| **Average** | 87.0% | 87.0% |
+
+These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
+
+### Evaluate in your hardware setup
+
+Once you have trained your model, you can run inference on your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
+
+```bash
+lerobot-record \
+ --robot.type=bi_so100_follower \
+ --robot.left_arm_port=/dev/ttyACM1 \
+ --robot.right_arm_port=/dev/ttyACM0 \
+ --robot.id=bimanual_follower \
+ --robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
+ left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
+ top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
+ }' \
+ --display_data=true \
+ --dataset.repo_id=/eval_groot-bimanual \
+ --dataset.num_episodes=10 \
+    --dataset.single_task="Grab and handover the red cube to the other arm" \
+    --dataset.episode_time_s=30 \
+    --dataset.reset_time_s=10 \
+    --policy.path=/groot-bimanual # your trained model
+```
+
+## License
+
+This model follows the **Apache 2.0 License**, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T).
diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx
index 0d8fd56e5..000df2a19 100644
--- a/docs/source/il_robots.mdx
+++ b/docs/source/il_robots.mdx
@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
Then store your Hugging Face repository name in a variable:
```bash
-HF_USER=$(huggingface-cli whoami | head -n 1)
+HF_USER=$(hf auth whoami | head -n 1)
echo $HF_USER
```
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index f5fd09acd..da42c90ef 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -1,8 +1,15 @@
# Installation
+## Install [`miniforge`](https://conda-forge.org/download/)
+
+```bash
+wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
+bash Miniforge3-$(uname)-$(uname -m).sh
+```
+
## Environment Setup
-Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
+Create a virtual environment with Python 3.10, using conda:
```bash
conda create -y -n lerobot python=3.10
@@ -14,7 +21,7 @@ Then activate your conda environment, you have to do this each time you open a s
conda activate lerobot
```
-When using `miniconda`, install `ffmpeg` in your environment:
+When using `conda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
@@ -74,6 +81,9 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
+> [!NOTE]
+> For lerobot 0.4.0, if you want to install the `pi` extra, you will have to run: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
+
### Troubleshooting
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx
index ed9dc8dd5..e1587be91 100644
--- a/docs/source/integrate_hardware.mdx
+++ b/docs/source/integrate_hardware.mdx
@@ -208,34 +208,36 @@ LeRobot supports saving and loading calibration data automatically. This is usef
```python
-> @property
-> def is_calibrated(self) -> bool:
-> return True
->
-> def calibrate(self) -> None:
-> pass
-> ```
+@property
+def is_calibrated(self) -> bool:
+ return True
+
+def calibrate(self) -> None:
+ pass
+```
+
### `is_calibrated`
This should reflect whether your robot has the required calibration loaded.
-```
-python
+
+```python
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
```
+
### `calibrate()`
The goal of the calibration is twofold:
- - Know the physical range of motion of each motors in order to only send commands within this range.
- - Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
+
+- Know the physical range of motion of each motor in order to only send commands within this range.
+- Normalize raw motor positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete values dependent on the specific motor used, which will not replicate elsewhere.
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
-
```python
def calibrate(self) -> None:
diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx
index d36fe0ce4..d15f7e91f 100644
--- a/docs/source/pi0.mdx
+++ b/docs/source/pi0.mdx
@@ -28,6 +28,11 @@ As described by Physical Intelligence, while AI has achieved remarkable success
pip install -e ".[pi]"
```
+ > [!NOTE]
+ > For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
+ >
+ > This will be solved in the next patch release
+
## Training Data and Capabilities
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx
index b6267fc5e..29b797935 100644
--- a/docs/source/pi05.mdx
+++ b/docs/source/pi05.mdx
@@ -36,6 +36,11 @@ This diverse training mixture creates a "curriculum" that enables generalization
pip install -e ".[pi]"
```
+ > [!NOTE]
+ > For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
+ >
+ > This will be solved in the next patch release
+
## Usage
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
diff --git a/docs/source/policy_groot_README.md b/docs/source/policy_groot_README.md
new file mode 100644
index 000000000..efcd76ebe
--- /dev/null
+++ b/docs/source/policy_groot_README.md
@@ -0,0 +1,27 @@
+## Research Paper
+
+Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
+
+## Repository
+
+Code: https://github.com/NVIDIA/Isaac-GR00T
+
+## Citation
+
+```bibtex
+@inproceedings{gr00tn1_2025,
+ archivePrefix = {arxiv},
+ eprint = {2503.14734},
+ title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
+  author = {NVIDIA and Johan Bjorck and Fernando Castañeda and Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
+ month = {March},
+ year = {2025},
+ booktitle = {ArXiv Preprint},
+}
+```
+
+## Additional Resources
+
+Blog: https://developer.nvidia.com/isaac/gr00t
+
+Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
diff --git a/docs/source/rtc.mdx b/docs/source/rtc.mdx
new file mode 100644
index 000000000..f63c00fca
--- /dev/null
+++ b/docs/source/rtc.mdx
@@ -0,0 +1,188 @@
+# Real-Time Chunking (RTC)
+
+Real-Time Chunking (RTC) is an inference-time method that allows large, flow-matching based robotic policies, such as [Pi0](./pi0), [Pi0.5](./pi05), and [SmolVLA](./smolvla), to produce smooth, continuous, and reactive motion despite having high inference latency.
+
+These policies generate chunks of future actions (e.g., 50 steps at a time) instead of single actions.
+Because the models are large, producing each chunk takes longer than the time it takes the robot to execute it.
+Naively executing chunks leads to problems such as pauses, jerky transitions, or sudden changes in strategy whenever the next chunk arrives late or disagrees with the previously executed actions.
+
+RTC solves this by asynchronously generating the next chunk while the robot continues executing the current one, and by guiding the new chunk so it aligns smoothly with the portion of the previous chunk that has already been executed.
+
+## How RTC Works (simplified)
+
+RTC lets the robot think ahead while it’s still moving. When the robot is carrying out one chunk of actions, RTC starts creating the next chunk early.
+But since the robot has already moved a bit by the time the new chunk is ready, RTC has to make sure the new chunk still lines up smoothly with what the robot is currently doing.
+
+To do this, RTC treats the beginning of the new chunk like an inpainting or “fill-in-the-gaps” problem:
+it gently adjusts the first part of the new chunk so it blends naturally with the robot’s ongoing motion. The result: no pauses and no sudden jumps.
+
+In technical terms, RTC adds a guidance term to the flow-matching denoising process that forces the overlapping timesteps of the new chunk to stay close to the executed portion of the previous chunk, typically using a soft transition mask.
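+
+As a rough sketch of the idea (not LeRobot's actual implementation), the soft transition mask can be thought of as per-timestep guidance weights: full weight for the steps that will still be executed from the old chunk while inference runs, a decay (e.g. exponential, matching the `EXP` schedule described below) over the rest of the overlap, and zero beyond the execution horizon:
+
+```python
+import torch
+
+
+def prefix_attention_weights(execution_horizon: int, inference_delay: int, chunk_size: int) -> torch.Tensor:
+    """Per-timestep guidance weights over a new action chunk (illustrative sketch)."""
+    weights = torch.zeros(chunk_size)
+    # Steps the robot will execute from the previous chunk before the new one arrives: full weight
+    weights[:inference_delay] = 1.0
+    overlap = execution_horizon - inference_delay
+    if overlap > 0:
+        # Exponential decay over the remaining overlap region
+        weights[inference_delay:execution_horizon] = torch.exp(
+            -torch.arange(overlap, dtype=torch.float32) / overlap
+        )
+    return weights
+
+
+# During denoising, the guidance term pulls the overlapping steps of the new chunk toward the
+# previous chunk's leftover actions, scaled by these weights.
+```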
+
+## Quick Start
+
+### Installation
+
+RTC is built into LeRobot. Just install the policy dependencies you need:
+
+```bash
+# For Pi0 or Pi0.5
+pip install -e ".[pi]"
+
+# For SmolVLA
+pip install -e ".[smolvla]"
+```
+
+### Using RTC with Pi0
+
+You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
+The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
+
+```python
+from lerobot.policies.pi0 import PI0Policy, PI0Config
+from lerobot.configs.types import RTCAttentionSchedule
+from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from lerobot.policies.rtc.action_queue import ActionQueue
+
+# Load Pi0 with RTC enabled
+policy_cfg = PI0Config()
+
+# Enable RTC
+policy_cfg.rtc_config = RTCConfig(
+ enabled=True,
+ execution_horizon=10, # How many steps to blend with previous chunk
+ max_guidance_weight=10.0, # How strongly to enforce consistency
+ prefix_attention_schedule=RTCAttentionSchedule.EXP, # Exponential blend
+)
+
+# Load the policy
+policy = PI0Policy.from_pretrained("lerobot/pi0_base", policy_cfg=policy_cfg, device="cuda")
+
+# Now use predict_action_chunk with RTC parameters
+inference_delay = 4  # Number of control steps of inference latency; estimate this value from the policy's measured inference latency
+
+# Initialize the action queue
+action_queue = ActionQueue(policy_cfg.rtc_config)
+
+# Start in a separate thread with the following function
+def get_actions():
+ while True:
+ if should_get_actions:
+
+ prev_actions = action_queue.get_left_over()
+ obs = get_robot_observations(robot)
+
+ # Generate actions WITH RTC
+ actions = policy.predict_action_chunk(
+ obs,
+ inference_delay=inference_delay,
+ prev_chunk_left_over=prev_actions,
+ )
+
+ action_queue.merge(
+ actions, actions, inference_delay
+ )
+
+for step in range(num_steps):
+ action = action_queue.get()
+
+    # Execute the action on the robot
+ execute_actions(action)
+```
+
+## Key Parameters
+
+`RTCConfig` has the following parameters to tune:
+
+**`execution_horizon`**: How many timesteps from the previous chunk to maintain consistency with. Higher values mean smoother transitions but potentially less reactivity.
+
+Typical values: 8-12 steps
+
+```python
+RTCConfig(execution_horizon=10)
+```
+
+**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy. For 10-step flow matching (SmolVLA, Pi0, Pi0.5), a value of 10.0 works well.
+
+**`prefix_attention_schedule`**: How to weight consistency across the overlap region.
+
+- `LINEAR`: Linear decay from inference_delay to execution_horizon
+- `EXP`: Exponential decay (recommended for getting started)
+- `ONES`: Full weight across entire execution_horizon
+- `ZEROS`: Binary (full weight up to inference_delay, then zero)
+
+**`inference_delay`**: How many timesteps of inference latency your system has. This is passed to `predict_action_chunk()` rather than the config, since it may vary at runtime.
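+
+A simple way to estimate it is to time a warm policy call and convert seconds to control steps (a sketch; `policy`, `obs`, and the kwarg-free call are placeholders for your own setup):
+
+```python
+import math
+import time
+
+fps = 30  # control-loop frequency of your robot (Hz)
+
+# Time one warm call to the policy (placeholder objects from your own pipeline)
+start = time.perf_counter()
+_ = policy.predict_action_chunk(obs)
+latency_s = time.perf_counter() - start
+
+# Round up: this many control steps elapse before the new chunk is ready
+inference_delay = math.ceil(latency_s * fps)
+```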
+
+## Testing RTC Offline
+
+Before running on a real robot, test RTC with dataset samples to visualize how it works:
+
+```bash
+python examples/rtc/eval_dataset.py \
+ --policy.path=lerobot/pi0_libero_finetuned \
+ --dataset.repo_id=HuggingFaceVLA/libero \
+ --rtc.execution_horizon=10 \
+ --rtc.max_guidance_weight=10.0 \
+ --device=cuda
+```
+
+The script generates a visualization of the denoising process, comparing standard generation (left) with RTC (right). In the RTC plots, you can see how the first few steps (blue/purple lines) are guided to match the red ground truth trajectory (previous chunk's tail), ensuring a smooth transition between chunks.
+
+
+
+
+
+## Testing RTC with a Real Robot
+
+```bash
+python examples/rtc/eval_with_real_robot.py \
+ --policy.path=${HF_USERNAME}/policy_repo_id \
+ --robot.type=so100_follower \
+ --robot.port=/dev/tty.usbmodem58FA0834591 \
+ --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
+ --task="Move green small object into the purple platform" \
+ --duration=120 \
+ --device=cuda
+```
+
+## How It Differs from the Async Inference in LeRobot
+
+Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
+
+| Aspect | Async Inference | RTC |
+| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
+| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
+| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
+| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
+| **Best Used** | Large models with high inference latency | Flow-matching based policies |
+
+**Use both together** for maximum smoothness and reactivity!
+
+## Advanced: Debug Tracking
+
+RTC includes built-in debug tracking to help you understand what's happening during inference:
+
+```python
+# Enable debug tracking
+policy_cfg.rtc_config.debug = True
+policy_cfg.rtc_config.debug_maxlen = 100
+
+# After inference, access debug data
+debug_data = policy.rtc_processor.get_debug_data()
+
+# Visualize denoising steps, corrections, etc.
+from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
+visualizer = RTCDebugVisualizer()
+# ... create plots
+```
+
+See `examples/rtc/eval_dataset.py` for a complete example of visualization.
+
+## References
+
+- [Smooth-As-Butter Robot Policies](https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html) - Excellent technical explanation with real robot results
+- [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/research/real_time_chunking) - Original paper and research
+- [Kinetix RTC Implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix) - Reference implementation from Physical Intelligence
diff --git a/examples/dataset/load_lerobot_dataset.py b/examples/dataset/load_lerobot_dataset.py
index a96c170cf..a69169814 100644
--- a/examples/dataset/load_lerobot_dataset.py
+++ b/examples/dataset/load_lerobot_dataset.py
@@ -132,17 +132,15 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
-# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
-# PyTorch datasets.
-dataloader = torch.utils.data.DataLoader(
- dataset,
- num_workers=4,
- batch_size=32,
- shuffle=True,
-)
-
-for batch in dataloader:
- print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
- print(f"{batch['observation.state'].shape=}") # (32, 6, c)
- print(f"{batch['action'].shape=}") # (32, 64, c)
- break
+if __name__ == "__main__":
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ num_workers=4,
+ batch_size=32,
+ shuffle=True,
+ )
+ for batch in dataloader:
+ print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
+ print(f"{batch['observation.state'].shape=}") # (32, 6, c)
+ print(f"{batch['action'].shape=}") # (32, 64, c)
+ break
diff --git a/examples/port_datasets/slurm_aggregate_shards.py b/examples/port_datasets/slurm_aggregate_shards.py
index 4e1b71a31..af5473c79 100644
--- a/examples/port_datasets/slurm_aggregate_shards.py
+++ b/examples/port_datasets/slurm_aggregate_shards.py
@@ -15,16 +15,12 @@
# limitations under the License.
import argparse
-import logging
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
-from port_datasets.droid_rlds.port_droid import DROID_SHARDS
-
-from lerobot.datasets.aggregate import aggregate_datasets
-from lerobot.utils.utils import init_logging
+from port_droid import DROID_SHARDS
class AggregateDatasets(PipelineStep):
@@ -38,6 +34,11 @@ class AggregateDatasets(PipelineStep):
self.aggr_repo_id = aggregated_repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
+ import logging
+
+ from lerobot.datasets.aggregate import aggregate_datasets
+ from lerobot.utils.utils import init_logging
+
init_logging()
# Since aggregate_datasets already handles parallel processing internally,
diff --git a/examples/port_datasets/slurm_port_shards.py b/examples/port_datasets/slurm_port_shards.py
index 3bb4c135c..657ea870c 100644
--- a/examples/port_datasets/slurm_port_shards.py
+++ b/examples/port_datasets/slurm_port_shards.py
@@ -20,7 +20,7 @@ from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
-from port_datasets.droid_rlds.port_droid import DROID_SHARDS
+from port_droid import DROID_SHARDS
class PortDroidShards(PipelineStep):
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
- from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
+ from port_droid import port_droid, validate_dataset
from lerobot.utils.utils import init_logging
diff --git a/examples/port_datasets/slurm_upload.py b/examples/port_datasets/slurm_upload.py
index ade1ef874..55002c0be 100644
--- a/examples/port_datasets/slurm_upload.py
+++ b/examples/port_datasets/slurm_upload.py
@@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
-from port_datasets.droid_rlds.port_droid import DROID_SHARDS
+from port_droid import DROID_SHARDS
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
@@ -185,11 +185,11 @@ class UploadDataset(PipelineStep):
def make_upload_executor(
- repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
+ repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True
):
kwargs = {
"pipeline": [
- UploadDataset(repo_id),
+ UploadDataset(repo_id, private=private),
],
"logging_dir": str(logs_dir / job_name),
}
@@ -267,6 +267,12 @@ def main():
default="1950M",
help="Memory per cpu that each worker will use.",
)
+ parser.add_argument(
+ "--private",
+ action="store_true",
+ default=False,
+ help="Whether to create a private repository.",
+ )
init_logging()
diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py
new file mode 100644
index 000000000..4652df107
--- /dev/null
+++ b/examples/rtc/eval_dataset.py
@@ -0,0 +1,951 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Evaluate Real-Time Chunking (RTC) performance on dataset samples.
+
+This script takes two random samples from a dataset:
+- Uses actions from the first sample as previous chunk
+- Generates new actions for the second sample with and without RTC
+
+It compares action predictions with and without RTC on dataset samples,
+measuring consistency and ground truth alignment.
+
+Usage:
+ # Basic usage with smolvla policy
+ uv run python examples/rtc/eval_dataset.py \
+ --policy.path=helper2424/smolvla_check_rtc_last3 \
+ --dataset.repo_id=helper2424/check_rtc \
+ --rtc.execution_horizon=8 \
+ --device=mps \
+ --rtc.max_guidance_weight=10.0 \
+ --rtc.prefix_attention_schedule=EXP \
+ --seed=10
+
+ # Basic usage with pi0.5 policy
+ uv run python examples/rtc/eval_dataset.py \
+ --policy.path=lerobot/pi05_libero_finetuned \
+ --dataset.repo_id=HuggingFaceVLA/libero \
+ --rtc.execution_horizon=10 \
+        --device=mps \
+ --seed=10
+
+ # Basic usage with pi0.5 policy with cuda device
+ uv run python examples/rtc/eval_dataset.py \
+ --policy.path=lerobot/pi05_libero_finetuned \
+ --dataset.repo_id=HuggingFaceVLA/libero \
+ --rtc.execution_horizon=8 \
+ --device=cuda
+
+ # Basic usage with pi0 policy with cuda device
+ uv run python examples/rtc/eval_dataset.py \
+ --policy.path=lerobot/pi0_libero_finetuned \
+ --dataset.repo_id=HuggingFaceVLA/libero \
+ --rtc.execution_horizon=8 \
+ --device=cuda
+
+ uv run python examples/rtc/eval_dataset.py \
+ --policy.path=lipsop/reuben_pi0 \
+ --dataset.repo_id=ReubenLim/so101_cube_in_cup \
+ --rtc.execution_horizon=8 \
+ --device=cuda
+
+ # With torch.compile for faster inference (PyTorch 2.0+)
+ # Note: CUDA graphs disabled by default due to in-place ops in denoising loop
+ uv run python examples/rtc/eval_dataset.py \
+ --policy.path=helper2424/smolvla_check_rtc_last3 \
+ --dataset.repo_id=helper2424/check_rtc \
+ --rtc.execution_horizon=8 \
+ --device=mps \
+ --use_torch_compile=true \
+ --torch_compile_mode=max-autotune
+
+ # With torch.compile on CUDA (CUDA graphs disabled by default)
+ uv run python examples/rtc/eval_dataset.py \
+ --policy.path=helper2424/smolvla_check_rtc_last3 \
+ --dataset.repo_id=helper2424/check_rtc \
+ --rtc.execution_horizon=8 \
+ --device=cuda \
+ --use_torch_compile=true \
+ --torch_compile_mode=reduce-overhead
+
+ # Enable CUDA graphs (advanced - may cause tensor aliasing errors)
+ uv run python examples/rtc/eval_dataset.py \
+ --policy.path=helper2424/smolvla_check_rtc_last3 \
+ --dataset.repo_id=helper2424/check_rtc \
+ --use_torch_compile=true \
+ --torch_compile_backend=inductor \
+ --torch_compile_mode=max-autotune \
+ --torch_compile_disable_cudagraphs=false
+"""
+
+import gc
+import logging
+import os
+import random
+from dataclasses import dataclass, field
+
+import numpy as np
+import torch
+
+try:
+ import matplotlib.pyplot as plt
+
+ MATPLOTLIB_AVAILABLE = True
+except ImportError:
+ MATPLOTLIB_AVAILABLE = False
+ plt = None
+
+from lerobot.configs import parser
+from lerobot.configs.default import DatasetConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import RTCAttentionSchedule
+from lerobot.datasets.factory import resolve_delta_timestamps
+from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.policies.factory import get_policy_class, make_pre_post_processors
+from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
+from lerobot.utils.hub import HubMixin
+from lerobot.utils.utils import init_logging
+
+
+def set_seed(seed: int):
+ """Set random seed for reproducibility."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if torch.backends.mps.is_available():
+ torch.mps.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def _check_matplotlib_available():
+ """Check if matplotlib is available, raise helpful error if not."""
+ if not MATPLOTLIB_AVAILABLE:
+ raise ImportError(
+ "matplotlib is required for RTC debug visualizations. "
+ "Please install it by running:\n"
+ " uv pip install matplotlib"
+ )
+
+
+@dataclass
+class RTCEvalConfig(HubMixin):
+ """Configuration for RTC evaluation."""
+
+ # Policy configuration
+ policy: PreTrainedConfig | None = None
+
+ # Dataset configuration
+ dataset: DatasetConfig = field(default_factory=DatasetConfig)
+
+ # RTC configuration
+ rtc: RTCConfig = field(
+ default_factory=lambda: RTCConfig(
+ enabled=True,
+ execution_horizon=20,
+ max_guidance_weight=10.0,
+ prefix_attention_schedule=RTCAttentionSchedule.EXP,
+ debug=True,
+ debug_maxlen=1000,
+ )
+ )
+
+ # Device configuration
+ device: str | None = field(
+ default=None,
+ metadata={"help": "Device to run on (cuda, cpu, mps, auto)"},
+ )
+
+ # Output configuration
+ output_dir: str = field(
+ default="rtc_debug_output",
+ metadata={"help": "Directory to save debug visualizations"},
+ )
+
+ # Seed configuration
+ seed: int = field(
+ default=42,
+ metadata={"help": "Random seed for reproducibility"},
+ )
+
+ inference_delay: int = field(
+ default=4,
+ metadata={"help": "Inference delay for RTC"},
+ )
+
+ # Torch compile configuration
+ use_torch_compile: bool = field(
+ default=False,
+ metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
+ )
+
+ torch_compile_backend: str = field(
+ default="inductor",
+ metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
+ )
+
+ torch_compile_mode: str = field(
+ default="default",
+ metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
+ )
+
+ torch_compile_disable_cudagraphs: bool = field(
+ default=True,
+ metadata={
+ "help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
+ "operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
+ },
+ )
+
+ def __post_init__(self):
+ # Parse policy path
+ policy_path = parser.get_path_arg("policy")
+ if policy_path:
+ cli_overrides = parser.get_cli_overrides("policy")
+ self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+ self.policy.pretrained_path = policy_path
+ else:
+ raise ValueError("Policy path is required (--policy.path)")
+
+ # Auto-detect device if not specified
+ if self.device is None or self.device == "auto":
+ if torch.cuda.is_available():
+ self.device = "cuda"
+ elif torch.backends.mps.is_available():
+ self.device = "mps"
+ else:
+ self.device = "cpu"
+ logging.info(f"Auto-detected device: {self.device}")
+
+ @classmethod
+ def __get_path_fields__(cls) -> list[str]:
+ """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
+ return ["policy"]
+
+
+class RTCEvaluator:
+ """Evaluator for RTC on dataset samples."""
+
+ def __init__(self, cfg: RTCEvalConfig):
+ self.cfg = cfg
+ self.device = cfg.device
+
+ # Load dataset with proper delta_timestamps based on policy configuration
+ # Calculate delta_timestamps using the same logic as make_dataset factory
+ logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
+
+ # Get dataset metadata to extract FPS
+ ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id)
+
+ # Calculate delta_timestamps from policy's delta_indices
+ delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
+
+ # Create dataset with calculated delta_timestamps
+ self.dataset = LeRobotDataset(
+ cfg.dataset.repo_id,
+ delta_timestamps=delta_timestamps,
+ )
+ logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
+
+ # Create preprocessor/postprocessor
+ self.preprocessor, self.postprocessor = make_pre_post_processors(
+ policy_cfg=cfg.policy,
+ pretrained_path=cfg.policy.pretrained_path,
+ preprocessor_overrides={
+ "device_processor": {"device": self.device},
+ },
+ )
+
+ logging.info("=" * 80)
+ logging.info("Ready to run evaluation with sequential policy loading:")
+ logging.info(" 1. policy_prev_chunk - Generate reference chunk, then destroy")
+ logging.info(" 2. policy_no_rtc - Generate without RTC, then destroy")
+ logging.info(" 3. policy_rtc - Generate with RTC, then destroy")
+ logging.info(" Note: Only one policy in memory at a time for efficient memory usage")
+ logging.info("=" * 80)
+
+ def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
+ """Initialize a single policy instance with specified RTC configuration.
+
+ Args:
+ name: Name identifier for logging purposes
+ rtc_enabled: Whether to enable RTC for this policy
+ rtc_debug: Whether to enable debug tracking for this policy
+
+ Returns:
+ Configured policy instance with optional torch.compile applied
+ """
+ logging.info(f"Initializing {name}...")
+
+ # Load policy from pretrained
+ policy_class = get_policy_class(self.cfg.policy.type)
+
+ config = PreTrainedConfig.from_pretrained(self.cfg.policy.pretrained_path)
+
+ if self.cfg.policy.type == "pi05" or self.cfg.policy.type == "pi0":
+ config.compile_model = self.cfg.use_torch_compile
+
+ policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=config)
+ policy = policy.to(self.device)
+ policy.eval()
+
+ # Configure RTC
+ rtc_config = RTCConfig(
+ enabled=rtc_enabled,
+ execution_horizon=self.cfg.rtc.execution_horizon,
+ max_guidance_weight=self.cfg.rtc.max_guidance_weight,
+ prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
+ debug=rtc_debug,
+ debug_maxlen=self.cfg.rtc.debug_maxlen,
+ )
+ policy.config.rtc_config = rtc_config
+ policy.init_rtc_processor()
+
+ logging.info(f" RTC enabled: {rtc_enabled}")
+ logging.info(f" RTC debug: {rtc_debug}")
+ logging.info(f" Policy config: {config}")
+
+ # Apply torch.compile to predict_action_chunk method if enabled
+ if self.cfg.use_torch_compile:
+ policy = self._apply_torch_compile(policy, name)
+
+ logging.info(f"✓ {name} initialized successfully")
+ return policy
+
+ def _apply_torch_compile(self, policy, policy_name: str):
+ """Apply torch.compile to the policy's predict_action_chunk method.
+
+ Args:
+ policy: Policy instance to compile
+ policy_name: Name for logging purposes
+
+ Returns:
+ Policy with compiled predict_action_chunk method
+ """
+
+ # PI models handle their own compilation
+ if policy.type == "pi05" or policy.type == "pi0":
+ return policy
+
+ try:
+ # Check if torch.compile is available (PyTorch 2.0+)
+ if not hasattr(torch, "compile"):
+ logging.warning(
+ f" [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
+ f"Current version: {torch.__version__}. Skipping compilation."
+ )
+ return policy
+
+ logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
+ logging.info(f" Backend: {self.cfg.torch_compile_backend}")
+ logging.info(f" Mode: {self.cfg.torch_compile_mode}")
+ logging.info(f" Disable CUDA graphs: {self.cfg.torch_compile_disable_cudagraphs}")
+ logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
+
+ # Compile the predict_action_chunk method
+ # - Debug tracker is excluded from compilation via @torch._dynamo.disable
+ # - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
+ compile_kwargs = {
+ "backend": self.cfg.torch_compile_backend,
+ "mode": self.cfg.torch_compile_mode,
+ }
+
+ # Disable CUDA graphs if requested (prevents tensor aliasing issues)
+ if self.cfg.torch_compile_disable_cudagraphs:
+ compile_kwargs["options"] = {"triton.cudagraphs": False}
+
+ original_method = policy.predict_action_chunk
+ compiled_method = torch.compile(original_method, **compile_kwargs)
+ policy.predict_action_chunk = compiled_method
+ logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
+
+ except Exception as e:
+ logging.error(f" [{policy_name}] Failed to apply torch.compile: {e}")
+ logging.warning(f" [{policy_name}] Continuing without torch.compile")
+
+ return policy
+
+ def _destroy_policy(self, policy, policy_name: str):
+ """Explicitly destroy a policy and free all associated memory.
+
+ This method performs aggressive cleanup to ensure maximum memory is freed,
+ which is critical for large models (e.g., VLAs with billions of parameters).
+
+ Args:
+ policy: Policy instance to destroy
+ policy_name: Name for logging purposes
+ """
+ logging.info(f" Destroying {policy_name} and freeing memory...")
+
+ try:
+ # Step 1: Move policy to CPU to free GPU/MPS memory
+ policy.cpu()
+
+ # Step 2: Delete the policy object
+ del policy
+
+ # Step 3: Force garbage collection to reclaim memory immediately
+ gc.collect()
+
+ # Step 4: Clear device-specific caches
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize() # Ensure all operations complete
+
+ if torch.backends.mps.is_available():
+ torch.mps.empty_cache()
+
+ logging.info(f" ✓ {policy_name} destroyed and memory freed")
+
+ except Exception as e:
+ logging.warning(f" Warning: Error during {policy_name} cleanup: {e}")
+
+ def run_evaluation(self):
+ """Run evaluation on two random dataset samples using three separate policies.
+
+        Note: Policies are deinitialized after each step to free memory. Three instances of a
+        large model (e.g., a VLA with billions of parameters) cannot fit in memory
+        simultaneously, so we delete and garbage-collect after each step to ensure only one
+        policy is loaded at a time.
+ """
+ # Create output directory
+ os.makedirs(self.cfg.output_dir, exist_ok=True)
+ logging.info(f"Output directory: {self.cfg.output_dir}")
+
+ logging.info("=" * 80)
+ logging.info("Starting RTC evaluation")
+ logging.info(f"Inference delay: {self.cfg.inference_delay}")
+ logging.info("=" * 80)
+
+ # Load two random samples from dataset
+ data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
+ loader_iter = iter(data_loader)
+ first_sample = next(loader_iter)
+ second_sample = next(loader_iter)
+
+ preprocessed_first_sample = self.preprocessor(first_sample)
+ preprocessed_second_sample = self.preprocessor(second_sample)
+
+ # ============================================================================
+ # Step 1: Generate previous chunk using policy_prev_chunk
+ # ============================================================================
+ # This policy is only used to generate the reference chunk and then freed
+ logging.info("=" * 80)
+ logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
+ logging.info("=" * 80)
+
+ # Initialize policy 1
+ policy_prev_chunk_policy = self._init_policy(
+ name="policy_prev_chunk",
+ rtc_enabled=False,
+ rtc_debug=False,
+ )
+ with torch.no_grad():
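+            # Keep the first 25 predicted actions (length hardcoded here) as the previous chunk's leftover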
+ prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
+ preprocessed_first_sample,
+ )[:, :25, :].squeeze(0)
+ logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
+
+ # Destroy policy_prev_chunk to free memory for large models
+ self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")
+
+ # ============================================================================
+ # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
+ # ============================================================================
+ logging.info("=" * 80)
+ logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
+ logging.info("=" * 80)
+
+ set_seed(self.cfg.seed)
+
+ # Initialize policy 2
+ policy_no_rtc_policy = self._init_policy(
+ name="policy_no_rtc",
+ rtc_enabled=False,
+ rtc_debug=True,
+ )
+
+ # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
+ noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
+ noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
+ noise_clone = noise.clone()
+ policy_no_rtc_policy.rtc_processor.reset_tracker()
+ with torch.no_grad():
+ no_rtc_actions = policy_no_rtc_policy.predict_action_chunk(
+ preprocessed_second_sample,
+ noise=noise,
+ )
+ no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
+ logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
+ logging.info(f" Generated no_rtc_actions shape: {no_rtc_actions.shape}")
+
+ # Destroy policy_no_rtc to free memory before loading policy_rtc
+ self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
+
+ # ============================================================================
+ # Step 3: Generate actions WITH RTC using policy_rtc
+ # ============================================================================
+ logging.info("=" * 80)
+ logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
+ logging.info("=" * 80)
+
+ set_seed(self.cfg.seed)
+
+ # Initialize policy 3
+ policy_rtc_policy = self._init_policy(
+ name="policy_rtc",
+ rtc_enabled=True,
+ rtc_debug=True,
+ )
+ policy_rtc_policy.rtc_processor.reset_tracker()
+ with torch.no_grad():
+ rtc_actions = policy_rtc_policy.predict_action_chunk(
+ preprocessed_second_sample,
+ noise=noise_clone,
+ inference_delay=self.cfg.inference_delay,
+ prev_chunk_left_over=prev_chunk_left_over,
+ execution_horizon=self.cfg.rtc.execution_horizon,
+ )
+ rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
+ logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
+ logging.info(f" Generated rtc_actions shape: {rtc_actions.shape}")
+
+ # Save num_steps before destroying policy (needed for plotting)
+ try:
+ num_steps = policy_rtc_policy.config.num_steps
+ except Exception as e:
+ logging.error(f" Error getting num_steps: {e}")
+ num_steps = policy_rtc_policy.config.num_inference_steps
+ logging.warning(f" Using num_inference_steps: {num_steps} instead of num_steps")
+
+ # Destroy policy_rtc after final use
+ self._destroy_policy(policy_rtc_policy, "policy_rtc")
+
+ # Plot and save results
+ logging.info("=" * 80)
+ logging.info("Plotting results...")
+ self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
+
+ # Plot final actions comparison
+ logging.info("=" * 80)
+ logging.info("Plotting final actions comparison...")
+ self.plot_final_actions_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over)
+
+ logging.info("=" * 80)
+ logging.info("Evaluation completed successfully")
+
+ def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
+ """Plot final action predictions comparison on a single chart.
+
+ Args:
+ rtc_actions: Final actions from RTC policy
+ no_rtc_actions: Final actions from non-RTC policy
+ prev_chunk_left_over: Previous chunk used as ground truth
+ """
+ _check_matplotlib_available()
+
+ # Remove batch dimension if present
+ rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
+ no_rtc_actions_plot = (
+ no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
+ )
+ prev_chunk_plot = prev_chunk_left_over.cpu()
+
+ # Create figure with 6 subplots (one per action dimension)
+ fig, axes = plt.subplots(6, 1, figsize=(16, 12))
+ fig.suptitle("Final Action Predictions Comparison (Raw)", fontsize=16)
+
+ # Plot each action dimension
+ for dim_idx, ax in enumerate(axes):
+ # Plot previous chunk (ground truth) in red
+ RTCDebugVisualizer.plot_waypoints(
+ [ax],
+ prev_chunk_plot[:, dim_idx : dim_idx + 1],
+ start_from=0,
+ color="red",
+ label="Previous Chunk (Ground Truth)",
+ linewidth=2.5,
+ alpha=0.8,
+ )
+
+ # Plot no-RTC actions in blue
+ RTCDebugVisualizer.plot_waypoints(
+ [ax],
+ no_rtc_actions_plot[:, dim_idx : dim_idx + 1],
+ start_from=0,
+ color="blue",
+ label="No RTC",
+ linewidth=2,
+ alpha=0.7,
+ )
+
+ # Plot RTC actions in green
+ RTCDebugVisualizer.plot_waypoints(
+ [ax],
+ rtc_actions_plot[:, dim_idx : dim_idx + 1],
+ start_from=0,
+ color="green",
+ label="RTC",
+ linewidth=2,
+ alpha=0.7,
+ )
+
+ # Add vertical lines for inference delay and execution horizon
+ inference_delay = self.cfg.inference_delay
+ execution_horizon = self.cfg.rtc.execution_horizon
+
+ if inference_delay > 0:
+ ax.axvline(
+ x=inference_delay - 1,
+ color="orange",
+ linestyle="--",
+ alpha=0.5,
+ label=f"Inference Delay ({inference_delay})",
+ )
+
+ if execution_horizon > 0:
+ ax.axvline(
+ x=execution_horizon,
+ color="purple",
+ linestyle="--",
+ alpha=0.5,
+ label=f"Execution Horizon ({execution_horizon})",
+ )
+
+ ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
+ ax.grid(True, alpha=0.3)
+
+ # Set x-axis ticks to show all integer values
+ max_len = max(rtc_actions_plot.shape[0], no_rtc_actions_plot.shape[0], prev_chunk_plot.shape[0])
+ ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
+ ax.set_xlim(-0.5, max_len - 0.5)
+
+ axes[-1].set_xlabel("Step", fontsize=10)
+
+ # Collect legend handles and labels from first subplot
+ handles, labels = axes[0].get_legend_handles_labels()
+ # Remove duplicates while preserving order
+ seen = set()
+ unique_handles = []
+ unique_labels = []
+ for handle, label in zip(handles, labels, strict=True):
+ if label not in seen:
+ seen.add(label)
+ unique_handles.append(handle)
+ unique_labels.append(label)
+
+ # Add legend outside the plot area (to the right)
+ fig.legend(
+ unique_handles,
+ unique_labels,
+ loc="center right",
+ fontsize=9,
+ bbox_to_anchor=(1.0, 0.5),
+ framealpha=0.9,
+ )
+
+ # Save figure
+ output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
+ fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right
+ fig.savefig(output_path, dpi=150, bbox_inches="tight")
+ logging.info(f"Saved final actions comparison to {output_path}")
+ plt.close(fig)
+
+ def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
+ _check_matplotlib_available()
+
+ # Create side-by-side figures for denoising visualization
+ fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
+ fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
+ fig_corr, axs_corr = self._create_figure("Correction: No RTC (left) vs RTC (right)")
+ fig_x1t, axs_x1t = self._create_figure(
+ "x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
+ )
+ self._plot_denoising_steps_from_tracker(
+ rtc_tracked_steps,
+ axs_xt[:, 1], # Right column for x_t
+ axs_vt[:, 1], # Right column for v_t
+ axs_corr[:, 1], # Right column for correction
+ axs_x1t[:, 1], # Right column for x1_t
+ num_steps,
+ add_labels=True, # Add labels for RTC (right column)
+ )
+
+ self._plot_denoising_steps_from_tracker(
+ no_rtc_tracked_steps,
+ axs_xt[:, 0], # Left column for x_t
+ axs_vt[:, 0], # Left column for v_t
+ axs_corr[:, 0], # Left column for correction
+ axs_x1t[:, 0], # Left column for x1_t
+ num_steps,
+ add_labels=False, # No labels for No RTC (left column)
+ )
+
+ # Plot no-RTC x_t data on right chart as orange dashed line for comparison
+ self._plot_no_rtc_xt_reference(no_rtc_tracked_steps, axs_xt[:, 1], num_steps)
+
+ # Plot ground truth on x_t axes
+ RTCDebugVisualizer.plot_waypoints(
+ axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
+ )
+
+ # Plot ground truth on x1_t axes
+ RTCDebugVisualizer.plot_waypoints(
+ axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
+ )
+
+ # Plot ground truth on x_t axes (no labels for left column)
+ RTCDebugVisualizer.plot_waypoints(
+ axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
+ )
+
+ RTCDebugVisualizer.plot_waypoints(
+ axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
+ )
+
+ # Add legends outside the plot area for each figure
+ self._add_figure_legend(fig_xt, axs_xt)
+ self._add_figure_legend(fig_vt, axs_vt)
+ self._add_figure_legend(fig_corr, axs_corr)
+ self._add_figure_legend(fig_x1t, axs_x1t)
+
+ # Save denoising plots
+ self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
+ self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
+ self._save_figure(fig_corr, os.path.join(self.cfg.output_dir, "denoising_correction_comparison.png"))
+ self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png"))
+
+ def _create_figure(self, title):
+ fig, axs = plt.subplots(6, 2, figsize=(24, 12))
+ fig.suptitle(title, fontsize=16)
+
+ for ax in axs[:, 0]:
+ ax.set_title("No RTC (N/A)" if ax == axs[0, 0] else "", fontsize=12)
+ for ax in axs[:, 1]:
+ ax.set_title("RTC" if ax == axs[0, 1] else "", fontsize=12)
+
+ return fig, axs
+
+ def _add_figure_legend(self, fig, axs):
+ """Add a legend outside the plot area on the right side.
+
+ Args:
+ fig: Matplotlib figure to add legend to
+ axs: Array of axes to collect legend handles from
+ """
+ # Collect all handles and labels from the first row of axes (right column)
+ handles, labels = axs[0, 1].get_legend_handles_labels()
+
+ # Remove duplicates while preserving order
+ seen = set()
+ unique_handles = []
+ unique_labels = []
+ for handle, label in zip(handles, labels, strict=True):
+ if label not in seen:
+ seen.add(label)
+ unique_handles.append(handle)
+ unique_labels.append(label)
+
+ # Add legend outside the plot area (to the right, close to charts)
+ if unique_handles:
+ fig.legend(
+ unique_handles,
+ unique_labels,
+ loc="center left",
+ fontsize=8,
+ bbox_to_anchor=(0.87, 0.5),
+ framealpha=0.9,
+ ncol=1,
+ )
+
+ def _save_figure(self, fig, path):
+ fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend/colorbar on right
+ fig.savefig(path, dpi=150, bbox_inches="tight")
+ logging.info(f"Saved figure to {path}")
+ plt.close(fig)
+
+ def _plot_denoising_steps_from_tracker(
+ self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps, add_labels=True
+ ):
+ """Plot denoising steps from tracker data.
+
+ Args:
+ tracked_steps: List of DebugStep objects containing debug steps
+ xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
+ vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
+ corr_axs: Matplotlib axes for correction plots (array of 6 axes)
+ x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
+ num_steps: Total number of denoising steps for colormap
+ add_labels: Whether to add legend labels for the plots
+ """
+
+ logging.info("=" * 80)
+ logging.info(f"Plotting {len(tracked_steps)} steps")
+
+ debug_steps = tracked_steps
+ if not debug_steps:
+ return
+
+ # Define colors for different denoise steps (using a colormap)
+ colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
+
+ for step_idx, debug_step in enumerate(debug_steps):
+ color = colors[step_idx % len(colors)]
+ label = f"Step {step_idx}" if add_labels else None
+
+ # Plot x_t
+ if debug_step.x_t is not None:
+ RTCDebugVisualizer.plot_waypoints(
+ xt_axs, debug_step.x_t, start_from=0, color=color, label=label
+ )
+
+ # Plot v_t
+ if debug_step.v_t is not None:
+ RTCDebugVisualizer.plot_waypoints(
+ vt_axs, debug_step.v_t, start_from=0, color=color, label=label
+ )
+
+ # Plot correction on separate axes
+ if debug_step.correction is not None:
+ RTCDebugVisualizer.plot_waypoints(
+ corr_axs,
+ debug_step.correction,
+ start_from=0,
+ color=color,
+ label=label,
+ )
+
+ # Plot x1_t (predicted state)
+ if x1t_axs is not None and debug_step.x1_t is not None:
+ x1t_label = f"x1_t Step {step_idx}" if add_labels else None
+ RTCDebugVisualizer.plot_waypoints(
+ x1t_axs,
+ debug_step.x1_t,
+ start_from=0,
+ color=color,
+ label=x1t_label,
+ )
+
+ # Plot error in orange dashed
+ if x1t_axs is not None and debug_step.err is not None:
+ error_chunk = (
+ debug_step.err[0].cpu().numpy()
+ if len(debug_step.err.shape) == 3
+ else debug_step.err.cpu().numpy()
+ )
+
+ num_dims = min(error_chunk.shape[-1], 6)
+ error_label = f"error Step {step_idx}" if add_labels else None
+ for j in range(num_dims):
+ x1t_axs[j].plot(
+ np.arange(0, error_chunk.shape[0]),
+ error_chunk[:, j],
+ color="orange",
+ linestyle="--",
+ alpha=0.7,
+ label=error_label,
+ )
+
+ # Recalculate axis limits after plotting to ensure proper scaling
+ self._rescale_axes(xt_axs)
+ self._rescale_axes(vt_axs)
+ self._rescale_axes(corr_axs)
+ self._rescale_axes(x1t_axs)
+
+ def _plot_no_rtc_xt_reference(self, no_rtc_tracked_steps, xt_axs, num_steps):
+ """Plot final no-RTC x_t data as orange dashed line on the RTC chart for comparison.
+
+ Args:
+ no_rtc_tracked_steps: List of DebugStep objects containing no-RTC debug steps
+ xt_axs: Matplotlib axes for x_t plots (array of 6 axes, right column)
+ num_steps: Total number of denoising steps for colormap
+ """
+ debug_steps = no_rtc_tracked_steps
+ if not debug_steps:
+ return
+
+ # Plot only the final x_t step as orange dashed line
+ final_step = debug_steps[-1]
+ logging.info("Plotting final no-RTC x_t step as orange dashed reference")
+
+ if final_step.x_t is not None:
+ x_t_chunk = (
+ final_step.x_t[0].cpu().numpy()
+ if len(final_step.x_t.shape) == 3
+ else final_step.x_t.cpu().numpy()
+ )
+
+ num_dims = min(x_t_chunk.shape[-1], 6)
+ for j in range(num_dims):
+ xt_axs[j].plot(
+ np.arange(0, x_t_chunk.shape[0]),
+ x_t_chunk[:, j],
+ color="orange",
+ linestyle="--",
+ alpha=0.7,
+ linewidth=2,
+ label="No RTC (final)" if j == 0 else "",
+ )
+
+ def _rescale_axes(self, axes):
+ """Rescale axes to show all data with proper margins.
+
+ Args:
+ axes: Array of matplotlib axes to rescale
+ """
+ for ax in axes:
+ ax.relim()
+ ax.autoscale_view()
+
+ # Add 10% margin to y-axis for better visualization
+ ylim = ax.get_ylim()
+ y_range = ylim[1] - ylim[0]
+ if y_range > 0: # Avoid division by zero
+ margin = y_range * 0.1
+ ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
+
+ # Set x-axis ticks to show all integer values
+ xlim = ax.get_xlim()
+ max_len = int(xlim[1]) + 1
+ if max_len > 0:
+ ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
+ ax.set_xlim(-0.5, max_len - 0.5)
+
+
+@parser.wrap()
+def main(cfg: RTCEvalConfig):
+ """Main entry point for RTC evaluation."""
+ # Set random seed for reproducibility
+ set_seed(cfg.seed)
+
+ init_logging()
+
+ logging.info("=" * 80)
+ logging.info("RTC Dataset Evaluation")
+ logging.info(f"Config: {cfg}")
+ logging.info("=" * 80)
+
+ evaluator = RTCEvaluator(cfg)
+ evaluator.run_evaluation()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py
new file mode 100644
index 000000000..6f051485a
--- /dev/null
+++ b/examples/rtc/eval_with_real_robot.py
@@ -0,0 +1,549 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
+
+This script demonstrates:
+1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
+2. Consuming actions from the policy while the robot executes
+3. Periodically requesting new action chunks in the background using threads
+4. Managing action buffers and timing for real-time operation
+
+For simulation environments, see eval_with_simulation.py
+
+Usage:
+    # Run with a real robot, RTC enabled
+ uv run examples/rtc/eval_with_real_robot.py \
+ --policy.path=helper2424/smolvla_check_rtc_last3 \
+ --policy.device=mps \
+ --rtc.enabled=true \
+ --rtc.execution_horizon=20 \
+ --robot.type=so100_follower \
+ --robot.port=/dev/tty.usbmodem58FA0834591 \
+ --robot.id=so100_follower \
+ --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
+ --task="Move green small object into the purple platform" \
+ --duration=120
+
+    # Run with a real robot, RTC disabled
+ uv run examples/rtc/eval_with_real_robot.py \
+ --policy.path=helper2424/smolvla_check_rtc_last3 \
+ --policy.device=mps \
+ --rtc.enabled=false \
+ --robot.type=so100_follower \
+ --robot.port=/dev/tty.usbmodem58FA0834591 \
+ --robot.id=so100_follower \
+ --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
+ --task="Move green small object into the purple platform" \
+ --duration=120
+
+    # Run with a real robot using the pi0.5 policy, RTC enabled
+ uv run examples/rtc/eval_with_real_robot.py \
+ --policy.path=helper2424/pi05_check_rtc \
+ --policy.device=mps \
+ --rtc.enabled=true \
+ --rtc.execution_horizon=20 \
+ --robot.type=so100_follower \
+ --robot.port=/dev/tty.usbmodem58FA0834591 \
+ --robot.id=so100_follower \
+ --robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
+ --task="Move green small object into the purple platform" \
+ --duration=120
+"""
+
+import logging
+import math
+import sys
+import time
+import traceback
+from dataclasses import dataclass, field
+from threading import Event, Lock, Thread
+
+import torch
+from torch import Tensor
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
+from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
+from lerobot.configs import parser
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import RTCAttentionSchedule
+from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
+from lerobot.policies.factory import get_policy_class, make_pre_post_processors
+from lerobot.policies.rtc.action_queue import ActionQueue
+from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from lerobot.policies.rtc.latency_tracker import LatencyTracker
+from lerobot.processor.factory import (
+ make_default_robot_action_processor,
+ make_default_robot_observation_processor,
+)
+from lerobot.rl.process import ProcessSignalHandler
+from lerobot.robots import ( # noqa: F401
+ Robot,
+ RobotConfig,
+ koch_follower,
+ so100_follower,
+ so101_follower,
+)
+from lerobot.robots.utils import make_robot_from_config
+from lerobot.utils.constants import OBS_IMAGES
+from lerobot.utils.hub import HubMixin
+from lerobot.utils.utils import init_logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class RobotWrapper:
+ def __init__(self, robot: Robot):
+ self.robot = robot
+ self.lock = Lock()
+
+ def get_observation(self) -> dict[str, Tensor]:
+ with self.lock:
+ return self.robot.get_observation()
+
+ def send_action(self, action: Tensor):
+ with self.lock:
+ self.robot.send_action(action)
+
+ def observation_features(self) -> list[str]:
+ with self.lock:
+ return self.robot.observation_features
+
+ def action_features(self) -> list[str]:
+ with self.lock:
+ return self.robot.action_features
+
+
+@dataclass
+class RTCDemoConfig(HubMixin):
+ """Configuration for RTC demo with action chunking policies and real robots."""
+
+ # Policy configuration
+ policy: PreTrainedConfig | None = None
+
+ # Robot configuration
+ robot: RobotConfig | None = None
+
+ # RTC configuration
+ rtc: RTCConfig = field(
+ default_factory=lambda: RTCConfig(
+ execution_horizon=10,
+ max_guidance_weight=1.0,
+ prefix_attention_schedule=RTCAttentionSchedule.EXP,
+ )
+ )
+
+ # Demo parameters
+ duration: float = 30.0 # Duration to run the demo (seconds)
+ fps: float = 10.0 # Action execution frequency (Hz)
+
+ # Compute device
+ device: str | None = None # Device to run on (cuda, cpu, auto)
+
+    # Queue-size threshold: when the action queue drops to or below this many actions, a new chunk is requested.
+    # It should be larger than the inference delay plus the execution horizon.
+ action_queue_size_to_get_new_actions: int = 30
+
+ # Task to execute
+ task: str = field(default="", metadata={"help": "Task to execute"})
+
+ # Torch compile configuration
+ use_torch_compile: bool = field(
+ default=False,
+ metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
+ )
+
+ torch_compile_backend: str = field(
+ default="inductor",
+ metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
+ )
+
+ torch_compile_mode: str = field(
+ default="default",
+ metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
+ )
+
+ torch_compile_disable_cudagraphs: bool = field(
+ default=True,
+ metadata={
+ "help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
+ "operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
+ },
+ )
+
+ def __post_init__(self):
+ # HACK: We parse again the cli args here to get the pretrained path if there was one.
+ policy_path = parser.get_path_arg("policy")
+ if policy_path:
+ cli_overrides = parser.get_cli_overrides("policy")
+ self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+ self.policy.pretrained_path = policy_path
+ else:
+ raise ValueError("Policy path is required")
+
+ # Validate that robot configuration is provided
+ if self.robot is None:
+ raise ValueError("Robot configuration must be provided")
+
+ @classmethod
+ def __get_path_fields__(cls) -> list[str]:
+ """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
+ return ["policy"]
+
+
+def is_image_key(k: str) -> bool:
+ return k.startswith(OBS_IMAGES)
+
+
+def get_actions(
+ policy,
+ robot: RobotWrapper,
+ robot_observation_processor,
+ action_queue: ActionQueue,
+ shutdown_event: Event,
+ cfg: RTCDemoConfig,
+):
+ """Thread function to request action chunks from the policy.
+
+ Args:
+ policy: The policy instance (SmolVLA, Pi0, etc.)
+ robot: The robot instance for getting observations
+ robot_observation_processor: Processor for raw robot observations
+ action_queue: Queue to put new action chunks
+ shutdown_event: Event to signal shutdown
+ cfg: Demo configuration
+ """
+ try:
+ logger.info("[GET_ACTIONS] Starting get actions thread")
+
+ latency_tracker = LatencyTracker() # Track latency of action chunks
+ fps = cfg.fps
+ time_per_chunk = 1.0 / fps
+
+ dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
+ policy_device = policy.config.device
+
+ # Load preprocessor and postprocessor from pretrained files
+ # The stats are embedded in the processor .safetensors files
+ logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
+
+ preprocessor, postprocessor = make_pre_post_processors(
+ policy_cfg=cfg.policy,
+ pretrained_path=cfg.policy.pretrained_path,
+ dataset_stats=None, # Will load from pretrained processor files
+ preprocessor_overrides={
+ "device_processor": {"device": cfg.policy.device},
+ },
+ )
+
+ logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
+
+ get_actions_threshold = cfg.action_queue_size_to_get_new_actions
+
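+        # Without RTC there is no chunk blending, so only request a new chunk once the queue is empty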
+ if not cfg.rtc.enabled:
+ get_actions_threshold = 0
+
+ while not shutdown_event.is_set():
+ if action_queue.qsize() <= get_actions_threshold:
+ current_time = time.perf_counter()
+ action_index_before_inference = action_queue.get_action_index()
+ prev_actions = action_queue.get_left_over()
+
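+                # Convert the worst-case measured latency (seconds) into a number of control timesteps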
+ inference_latency = latency_tracker.max()
+ inference_delay = math.ceil(inference_latency / time_per_chunk)
+
+ obs = robot.get_observation()
+
+ # Apply robot observation processor
+ obs_processed = robot_observation_processor(obs)
+
+ obs_with_policy_features = build_dataset_frame(
+ dataset_features, obs_processed, prefix="observation"
+ )
+
+ for name in obs_with_policy_features:
+ obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
+ if "image" in name:
+ obs_with_policy_features[name] = (
+ obs_with_policy_features[name].type(torch.float32) / 255
+ )
+ obs_with_policy_features[name] = (
+ obs_with_policy_features[name].permute(2, 0, 1).contiguous()
+ )
+ obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
+ obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
+
+ obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
+ obs_with_policy_features["robot_type"] = (
+ robot.robot.name if hasattr(robot.robot, "name") else ""
+ )
+
+                preprocessed_obs = preprocessor(obs_with_policy_features)
+
+                # Request a new action chunk (RTC guidance is applied when enabled)
+                actions = policy.predict_action_chunk(
+                    preprocessed_obs,
+ inference_delay=inference_delay,
+ prev_chunk_left_over=prev_actions,
+ )
+
+ # Store original actions (before postprocessing) for RTC
+ original_actions = actions.squeeze(0).clone()
+
+ postprocessed_actions = postprocessor(actions)
+
+ postprocessed_actions = postprocessed_actions.squeeze(0)
+
+ new_latency = time.perf_counter() - current_time
+ new_delay = math.ceil(new_latency / time_per_chunk)
+ latency_tracker.add(new_latency)
+
+                if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
+                    logger.warning(
+                        "[GET_ACTIONS] action_queue_size_to_get_new_actions is too small; it should be larger than the inference delay plus the execution horizon."
+                    )
+
+ action_queue.merge(
+ original_actions, postprocessed_actions, new_delay, action_index_before_inference
+ )
+ else:
+ # Small sleep to prevent busy waiting
+ time.sleep(0.1)
+
+ logger.info("[GET_ACTIONS] get actions thread shutting down")
+ except Exception as e:
+ logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
+ logger.error(traceback.format_exc())
+ sys.exit(1)
+
+
+def actor_control(
+ robot: RobotWrapper,
+ robot_action_processor,
+ action_queue: ActionQueue,
+ shutdown_event: Event,
+ cfg: RTCDemoConfig,
+):
+ """Thread function to execute actions on the robot.
+
+ Args:
+ robot: The robot instance
+ action_queue: Queue to get actions from
+ shutdown_event: Event to signal shutdown
+ cfg: Demo configuration
+ """
+ try:
+ logger.info("[ACTOR] Starting actor thread")
+
+ action_count = 0
+ action_interval = 1.0 / cfg.fps
+
+ while not shutdown_event.is_set():
+ start_time = time.perf_counter()
+
+ # Try to get an action from the queue with timeout
+ action = action_queue.get()
+
+ if action is not None:
+ action = action.cpu()
+ action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
+ action_processed = robot_action_processor((action_dict, None))
+ robot.send_action(action_processed)
+
+ action_count += 1
+
+ dt_s = time.perf_counter() - start_time
+ time.sleep(max(0, (action_interval - dt_s) - 0.001))
+
+ logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
+ except Exception as e:
+ logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
+ logger.error(traceback.format_exc())
+ sys.exit(1)
+
+
+def _apply_torch_compile(policy, cfg: RTCDemoConfig):
+ """Apply torch.compile to the policy's predict_action_chunk method.
+
+ Args:
+ policy: Policy instance to compile
+ cfg: Configuration containing torch compile settings
+
+ Returns:
+ Policy with compiled predict_action_chunk method
+ """
+
+ # PI models handle their own compilation
+ if policy.type == "pi05" or policy.type == "pi0":
+ return policy
+
+ try:
+ # Check if torch.compile is available (PyTorch 2.0+)
+ if not hasattr(torch, "compile"):
+ logger.warning(
+ f"torch.compile is not available. Requires PyTorch 2.0+. "
+ f"Current version: {torch.__version__}. Skipping compilation."
+ )
+ return policy
+
+ logger.info("Applying torch.compile to predict_action_chunk...")
+ logger.info(f" Backend: {cfg.torch_compile_backend}")
+ logger.info(f" Mode: {cfg.torch_compile_mode}")
+ logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
+
+ # Compile the predict_action_chunk method
+ # - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
+ compile_kwargs = {
+ "backend": cfg.torch_compile_backend,
+ "mode": cfg.torch_compile_mode,
+ }
+
+ # Disable CUDA graphs if requested (prevents tensor aliasing issues)
+ if cfg.torch_compile_disable_cudagraphs:
+ compile_kwargs["options"] = {"triton.cudagraphs": False}
+
+ original_method = policy.predict_action_chunk
+ compiled_method = torch.compile(original_method, **compile_kwargs)
+ policy.predict_action_chunk = compiled_method
+ logger.info("✓ Successfully compiled predict_action_chunk")
+
+ except Exception as e:
+ logger.error(f"Failed to apply torch.compile: {e}")
+ logger.warning("Continuing without torch.compile")
+
+ return policy
+
+
+@parser.wrap()
+def demo_cli(cfg: RTCDemoConfig):
+ """Main entry point for RTC demo with draccus configuration."""
+
+ # Initialize logging
+ init_logging()
+
+ logger.info(f"Using device: {cfg.device}")
+
+ # Setup signal handler for graceful shutdown
+ signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
+ shutdown_event = signal_handler.shutdown_event
+
+ policy = None
+ robot = None
+ get_actions_thread = None
+ actor_thread = None
+
+ policy_class = get_policy_class(cfg.policy.type)
+
+ # Load config and set compile_model for pi0/pi05 models
+ config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
+
+ if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
+ config.compile_model = cfg.use_torch_compile
+
+ policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
+
+ # Turn on RTC
+ policy.config.rtc_config = cfg.rtc
+
+    # Initialize the RTC processor explicitly: by default it is not created
+    # when RTC is disabled in the config
+ policy.init_rtc_processor()
+
+ assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
+
+ policy = policy.to(cfg.device)
+ policy.eval()
+
+ # Apply torch.compile to predict_action_chunk method if enabled
+ if cfg.use_torch_compile:
+ policy = _apply_torch_compile(policy, cfg)
+
+ # Create robot
+ logger.info(f"Initializing robot: {cfg.robot.type}")
+ robot = make_robot_from_config(cfg.robot)
+ robot.connect()
+ robot_wrapper = RobotWrapper(robot)
+
+ # Create robot observation processor
+ robot_observation_processor = make_default_robot_observation_processor()
+ robot_action_processor = make_default_robot_action_processor()
+
+ # Create action queue for communication between threads
+ action_queue = ActionQueue(cfg.rtc)
+
+ # Start chunk requester thread
+ get_actions_thread = Thread(
+ target=get_actions,
+ args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
+ daemon=True,
+ name="GetActions",
+ )
+ get_actions_thread.start()
+ logger.info("Started get actions thread")
+
+ # Start action executor thread
+ actor_thread = Thread(
+ target=actor_control,
+ args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
+ daemon=True,
+ name="Actor",
+ )
+ actor_thread.start()
+ logger.info("Started actor thread")
+
+ # Main thread monitors for duration or shutdown
+ logger.info(f"Running demo for {cfg.duration} seconds...")
+ start_time = time.time()
+
+ while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
+ time.sleep(10)
+
+ # Log queue status periodically
+ if int(time.time() - start_time) % 5 == 0:
+ logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
+
+ if time.time() - start_time > cfg.duration:
+ break
+
+ logger.info("Demo duration reached or shutdown requested")
+
+ # Signal shutdown
+ shutdown_event.set()
+
+ # Wait for threads to finish
+ if get_actions_thread and get_actions_thread.is_alive():
+ logger.info("Waiting for chunk requester thread to finish...")
+ get_actions_thread.join()
+
+ if actor_thread and actor_thread.is_alive():
+ logger.info("Waiting for action executor thread to finish...")
+ actor_thread.join()
+
+ # Cleanup robot
+ if robot:
+ robot.disconnect()
+ logger.info("Robot disconnected")
+
+ logger.info("Cleanup completed")
+
+
+if __name__ == "__main__":
+ demo_cli()
+ logging.info("RTC demo finished")
diff --git a/examples/tutorial/act/act_training_example.py b/examples/tutorial/act/act_training_example.py
new file mode 100644
index 000000000..6db495ca2
--- /dev/null
+++ b/examples/tutorial/act/act_training_example.py
@@ -0,0 +1,98 @@
+"""This script demonstrates how to train ACT Policy on a real-world dataset."""
+
+from pathlib import Path
+
+import torch
+
+from lerobot.configs.types import FeatureType
+from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.datasets.utils import dataset_to_policy_features
+from lerobot.policies.act.configuration_act import ACTConfig
+from lerobot.policies.act.modeling_act import ACTPolicy
+from lerobot.policies.factory import make_pre_post_processors
+
+
+def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
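+    # Convert frame-index offsets into time offsets in seconds, as expected by delta_timestamps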
+ if delta_indices is None:
+ return [0]
+
+ return [i / fps for i in delta_indices]
+
+
+output_directory = Path("outputs/robot_learning_tutorial/act")
+output_directory.mkdir(parents=True, exist_ok=True)
+
+# Select your device
+device = torch.device("mps") # or "cuda" or "cpu"
+
+dataset_id = "lerobot/svla_so101_pickplace"
+
+# This specifies the inputs the model will be expecting and the outputs it will produce
+dataset_metadata = LeRobotDatasetMetadata(dataset_id)
+features = dataset_to_policy_features(dataset_metadata.features)
+
+output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+input_features = {key: ft for key, ft in features.items() if key not in output_features}
+
+cfg = ACTConfig(input_features=input_features, output_features=output_features)
+policy = ACTPolicy(cfg)
+preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
+
+policy.train()
+policy.to(device)
+
+# To perform action chunking, ACT expects a given number of actions as targets
+delta_timestamps = {
+ "action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
+}
+
+# add image features if they are present
+delta_timestamps |= {
+ k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
+}
+
+# Instantiate the dataset
+dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
+
+# Create the optimizer and dataloader for offline training
+optimizer = cfg.get_optimizer_preset().build(policy.parameters())
+batch_size = 32
+dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=device.type != "cpu",
+ drop_last=True,
+)
+
+# Number of training steps and logging frequency
+training_steps = 1
+log_freq = 1
+
+# Run training loop
+step = 0
+done = False
+while not done:
+ for batch in dataloader:
+ batch = preprocessor(batch)
+ loss, _ = policy.forward(batch)
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ if step % log_freq == 0:
+ print(f"step: {step} loss: {loss.item():.3f}")
+ step += 1
+ if step >= training_steps:
+ done = True
+ break
+
+# Save the policy checkpoint, alongside the pre/post processors
+policy.save_pretrained(output_directory)
+preprocessor.save_pretrained(output_directory)
+postprocessor.save_pretrained(output_directory)
+
+# Save all assets to the Hub
+policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
+preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
+postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
diff --git a/examples/tutorial/act/act_using_example.py b/examples/tutorial/act/act_using_example.py
new file mode 100644
index 000000000..8beef7654
--- /dev/null
+++ b/examples/tutorial/act/act_using_example.py
@@ -0,0 +1,57 @@
+import torch
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from lerobot.policies.act.modeling_act import ACTPolicy
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.policies.utils import build_inference_frame, make_robot_action
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+
+device = torch.device("mps") # or "cuda" or "cpu"
+model_id = "fracapuano/robot_learning_tutorial_act"
+model = ACTPolicy.from_pretrained(model_id)
+
+dataset_id = "lerobot/svla_so101_pickplace"
+# This only downloads the dataset metadata (typically tens of MB, even for large-scale datasets)
+dataset_metadata = LeRobotDatasetMetadata(dataset_id)
+preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
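+# The preprocessor normalizes observations with the dataset statistics; the postprocessor maps policy outputs back to the robot's action space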
+
+# find ports using lerobot-find-port
+follower_port = ... # something like "/dev/tty.usbmodem58760431631"
+
+# the robot ids are used to load the right calibration files
+follower_id = ... # something like "follower_so100"
+
+MAX_EPISODES = 5
+MAX_STEPS_PER_EPISODE = 20
+
+# Robot and environment configuration
+# Camera keys must match the names and resolutions of the cameras used for training!
+# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
+camera_config = {
+ "side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
+ "up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
+}
+
+robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
+robot = SO100Follower(robot_cfg)
+robot.connect()
+
+for _ in range(MAX_EPISODES):
+ for _ in range(MAX_STEPS_PER_EPISODE):
+ obs = robot.get_observation()
+ obs_frame = build_inference_frame(
+ observation=obs, ds_features=dataset_metadata.features, device=device
+ )
+
+ obs = preprocess(obs_frame)
+
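+        # select_action returns one action per call, drawing from the policy's internal action-chunk queue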
+ action = model.select_action(obs)
+ action = postprocess(action)
+
+ action = make_robot_action(action, dataset_metadata.features)
+
+ robot.send_action(action)
+
+ print("Episode finished! Starting new episode...")
diff --git a/examples/tutorial/async-inf/policy_server.py b/examples/tutorial/async-inf/policy_server.py
new file mode 100644
index 000000000..cd2b1f04c
--- /dev/null
+++ b/examples/tutorial/async-inf/policy_server.py
@@ -0,0 +1,11 @@
+from lerobot.async_inference.configs import PolicyServerConfig
+from lerobot.async_inference.policy_server import serve
+
+host = ... # something like "127.0.0.1" if you're exposing to localhost
+port = ... # something like 8080
+
+config = PolicyServerConfig(
+ host=host,
+ port=port,
+)
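+# serve() starts the policy server and blocks while it handles requests from connected robot clients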
+serve(config)
diff --git a/examples/tutorial/async-inf/robot_client.py b/examples/tutorial/async-inf/robot_client.py
new file mode 100644
index 000000000..3e46e5e31
--- /dev/null
+++ b/examples/tutorial/async-inf/robot_client.py
@@ -0,0 +1,55 @@
+import threading
+
+from lerobot.async_inference.configs import RobotClientConfig
+from lerobot.async_inference.helpers import visualize_action_queue_size
+from lerobot.async_inference.robot_client import RobotClient
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.robots.so100_follower import SO100FollowerConfig
+
+# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
+# check the config.json on the Hub for the policy you are using to see the expected camera specs
+camera_cfg = {
+ "up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
+ "side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
+}
+
+# find ports using lerobot-find-port
+follower_port = ... # something like "/dev/tty.usbmodem58760431631"
+
+# the robot ids are used to load the right calibration files
+follower_id = ... # something like "follower_so100"
+
+robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
+
+server_address = ... # something like "127.0.0.1:8080" if using localhost
+
+# 3. Create client configuration
+client_cfg = RobotClientConfig(
+ robot=robot_cfg,
+ server_address=server_address,
+ policy_device="mps",
+ policy_type="act",
+ pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
+    chunk_size_threshold=0.5,  # fraction of the action chunk left in the queue below which a new chunk is requested
+ actions_per_chunk=50, # make sure this is less than the max actions of the policy
+)
+
+# 4. Create and start client
+client = RobotClient(client_cfg)
+
+# 5. Provide a textual description of the task
+task = ...
+
+if client.start():
+ # Start action receiver thread
+ action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
+ action_receiver_thread.start()
+
+ try:
+ # Run the control loop
+ client.control_loop(task)
+ except KeyboardInterrupt:
+ client.stop()
+ action_receiver_thread.join()
+ # (Optionally) plot the action queue size
+ visualize_action_queue_size(client.action_queue_size)
diff --git a/examples/tutorial/diffusion/diffusion_training_example.py b/examples/tutorial/diffusion/diffusion_training_example.py
new file mode 100644
index 000000000..4c3fac09e
--- /dev/null
+++ b/examples/tutorial/diffusion/diffusion_training_example.py
@@ -0,0 +1,99 @@
+"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""
+
+from pathlib import Path
+
+import torch
+
+from lerobot.configs.types import FeatureType
+from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.datasets.utils import dataset_to_policy_features
+from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
+from lerobot.policies.factory import make_pre_post_processors
+
+
+def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
+ if delta_indices is None:
+ return [0]
+
+ return [i / fps for i in delta_indices]
+
+
+output_directory = Path("outputs/robot_learning_tutorial/diffusion")
+output_directory.mkdir(parents=True, exist_ok=True)
+
+# Select your device
+device = torch.device("mps") # or "cuda" or "cpu"
+
+dataset_id = "lerobot/svla_so101_pickplace"
+
+# The dataset metadata specifies the inputs the model expects and the outputs it produces
+dataset_metadata = LeRobotDatasetMetadata(dataset_id)
+features = dataset_to_policy_features(dataset_metadata.features)
+
+output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+input_features = {key: ft for key, ft in features.items() if key not in output_features}
+
+cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
+policy = DiffusionPolicy(cfg)
+preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
+
+policy.train()
+policy.to(device)
+
+# To perform action chunking, Diffusion Policy expects a given number of actions as targets (and a short history of observations)
+delta_timestamps = {
+ "observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
+ "action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
+}
+
+# add image features if they are present
+delta_timestamps |= {
+ k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
+}
+
+# Instantiate the dataset
+dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
+
+# Create the optimizer and dataloader for offline training
+optimizer = cfg.get_optimizer_preset().build(policy.parameters())
+batch_size = 32
+dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=device.type != "cpu",
+ drop_last=True,
+)
+
+# Number of training steps and logging frequency
+training_steps = 1
+log_freq = 1
+
+# Run training loop
+step = 0
+done = False
+while not done:
+ for batch in dataloader:
+ batch = preprocessor(batch)
+ loss, _ = policy.forward(batch)
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ if step % log_freq == 0:
+ print(f"step: {step} loss: {loss.item():.3f}")
+ step += 1
+ if step >= training_steps:
+ done = True
+ break
+
+# Save the policy checkpoint, alongside the pre/post processors
+policy.save_pretrained(output_directory)
+preprocessor.save_pretrained(output_directory)
+postprocessor.save_pretrained(output_directory)
+
+# Save all assets to the Hub
+policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
+preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
+postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
diff --git a/examples/tutorial/diffusion/diffusion_using_example.py b/examples/tutorial/diffusion/diffusion_using_example.py
new file mode 100644
index 000000000..fe516383f
--- /dev/null
+++ b/examples/tutorial/diffusion/diffusion_using_example.py
@@ -0,0 +1,60 @@
+import torch
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.policies.utils import build_inference_frame, make_robot_action
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+
+device = torch.device("mps") # or "cuda" or "cpu"
+model_id = "fracapuano/robot_learning_tutorial_diffusion"
+
+model = DiffusionPolicy.from_pretrained(model_id)
+
+dataset_id = "lerobot/svla_so101_pickplace"
+# This only downloads the dataset metadata (typically tens of MB, even for large-scale datasets)
+dataset_metadata = LeRobotDatasetMetadata(dataset_id)
+preprocess, postprocess = make_pre_post_processors(
+ model.config, model_id, dataset_stats=dataset_metadata.stats
+)
+
+MAX_EPISODES = 5
+MAX_STEPS_PER_EPISODE = 20
+
+
+# find ports using lerobot-find-port
+follower_port = ... # something like "/dev/tty.usbmodem58760431631"
+
+# the robot ids are used to load the right calibration files
+follower_id = ... # something like "follower_so100"
+
+# Robot and environment configuration
+# Camera keys must match the names and resolutions of the cameras used for training!
+# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
+camera_config = {
+ "side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
+ "up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
+}
+
+robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
+robot = SO100Follower(robot_cfg)
+robot.connect()
+
+
+for _ in range(MAX_EPISODES):
+ for _ in range(MAX_STEPS_PER_EPISODE):
+ obs = robot.get_observation()
+ obs_frame = build_inference_frame(
+ observation=obs, ds_features=dataset_metadata.features, device=device
+ )
+
+ obs = preprocess(obs_frame)
+
+ action = model.select_action(obs)
+ action = postprocess(action)
+ action = make_robot_action(action, dataset_metadata.features)
+ robot.send_action(action)
+
+ print("Episode finished! Starting new episode...")
diff --git a/examples/tutorial/pi0/using_pi0_example.py b/examples/tutorial/pi0/using_pi0_example.py
new file mode 100644
index 000000000..98fc4b4be
--- /dev/null
+++ b/examples/tutorial/pi0/using_pi0_example.py
@@ -0,0 +1,67 @@
+import torch
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.datasets.utils import hw_to_dataset_features
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.policies.pi0.modeling_pi0 import PI0Policy
+from lerobot.policies.utils import build_inference_frame, make_robot_action
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+
+MAX_EPISODES = 5
+MAX_STEPS_PER_EPISODE = 20
+
+device = torch.device("mps") # or "cuda" or "cpu"
+model_id = "lerobot/pi0_base"
+
+model = PI0Policy.from_pretrained(model_id)
+
+preprocess, postprocess = make_pre_post_processors(
+ model.config,
+ model_id,
+    # This override allows running on MPS; otherwise the device processor defaults to CUDA (if available)
+ preprocessor_overrides={"device_processor": {"device": str(device)}},
+)
+
+# find ports using lerobot-find-port
+follower_port = ... # something like "/dev/tty.usbmodem58760431631"
+
+# the robot ids are used to load the right calibration files
+follower_id = ... # something like "follower_so100"
+
+# Robot and environment configuration
+# Camera keys must match the names and resolutions of the cameras used for training!
+# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
+camera_config = {
+ "base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
+ "left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
+ "right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
+}
+
+robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
+robot = SO100Follower(robot_cfg)
+robot.connect()
+
+task = "" # something like "pick the red block"
+robot_type = "" # something like "so100_follower" for multi-embodiment datasets
+
+# This is used to match the raw observation keys to the keys expected by the policy
+action_features = hw_to_dataset_features(robot.action_features, "action")
+obs_features = hw_to_dataset_features(robot.observation_features, "observation")
+dataset_features = {**action_features, **obs_features}
+
+for _ in range(MAX_EPISODES):
+ for _ in range(MAX_STEPS_PER_EPISODE):
+ obs = robot.get_observation()
+ obs_frame = build_inference_frame(
+ observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
+ )
+
+ obs = preprocess(obs_frame)
+
+ action = model.select_action(obs)
+ action = postprocess(action)
+ action = make_robot_action(action, dataset_features)
+ robot.send_action(action)
+
+ print("Episode finished! Starting new episode...")
diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py
new file mode 100644
index 000000000..865053d36
--- /dev/null
+++ b/examples/tutorial/rl/hilserl_example.py
@@ -0,0 +1,345 @@
+import multiprocessing as mp
+import signal
+from pathlib import Path
+from queue import Empty, Full
+
+import torch
+import torch.optim as optim
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.utils import hw_to_dataset_features
+from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
+from lerobot.policies.sac.configuration_sac import SACConfig
+from lerobot.policies.sac.modeling_sac import SACPolicy
+from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+from lerobot.rl.buffer import ReplayBuffer
+from lerobot.rl.gym_manipulator import make_robot_env
+from lerobot.robots.so100_follower import SO100FollowerConfig
+from lerobot.teleoperators.so100_leader import SO100LeaderConfig
+from lerobot.teleoperators.utils import TeleopEvents
+
+LOG_EVERY = 10
+SEND_EVERY = 10
+
+
+def run_learner(
+ transitions_queue: mp.Queue,
+ parameters_queue: mp.Queue,
+ shutdown_event: mp.Event,
+ policy_learner: SACPolicy,
+ online_buffer: ReplayBuffer,
+ offline_buffer: ReplayBuffer,
+ lr: float = 3e-4,
+ batch_size: int = 32,
+ device: torch.device = "mps",
+):
+ """The learner process - trains SAC policy on transitions streamed from the actor, updating parameters
+ for the actor to adopt."""
+ policy_learner.train()
+ policy_learner.to(device)
+
+    # Create a plain Adam optimizer for the learner
+ optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
+
+ print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
+ print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
+
+ training_step = 0
+
+ while not shutdown_event.is_set():
+ # retrieve incoming transitions from the actor process
+ try:
+ transitions = transitions_queue.get(timeout=0.1)
+ for transition in transitions:
+ # HIL-SERL: Add ALL transitions to online buffer
+ online_buffer.add(**transition)
+
+ # HIL-SERL: Add ONLY human intervention transitions to offline buffer
+ is_intervention = transition.get("complementary_info", {}).get("is_intervention", False)
+ if is_intervention:
+ offline_buffer.add(**transition)
+ print(
+ f"[LEARNER] Human intervention detected! Added to offline buffer (now {len(offline_buffer)} transitions)"
+ )
+
+ except Empty:
+ pass # No transitions available, continue
+
+ # Train if we have enough data
+ if len(online_buffer) >= policy_learner.config.online_step_before_learning:
+ # Sample from online buffer (autonomous + human data)
+ online_batch = online_buffer.sample(batch_size // 2)
+
+ # Sample from offline buffer (human demonstrations only, either precollected or at runtime)
+ offline_batch = offline_buffer.sample(batch_size // 2)
+
+            # Combine the two batches - this is the key HIL-SERL mechanism: every update mixes autonomous experience with human demonstrations/interventions
+ batch = {}
+ for key in online_batch:
+ if key in offline_batch:
+ batch[key] = torch.cat([online_batch[key], offline_batch[key]], dim=0)
+ else:
+ batch[key] = online_batch[key]
+
+ loss, _ = policy_learner.forward(batch)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ training_step += 1
+
+ if training_step % LOG_EVERY == 0:
+ print(
+ f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
+ f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
+ )
+
+ # Send updated parameters to actor every 10 training steps
+ if training_step % SEND_EVERY == 0:
+ try:
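+                # Move the weights to CPU so they can be sent through the mp.Queue regardless of the learner's device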
+ state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
+ parameters_queue.put_nowait(state_dict)
+ print("[LEARNER] Sent updated parameters to actor")
+ except Full:
+                # The previous update has not been consumed by the actor yet; skip this one (should happen rarely)
+ pass
+
+ print("[LEARNER] Learner process finished")
+
+
+def run_actor(
+ transitions_queue: mp.Queue,
+ parameters_queue: mp.Queue,
+ shutdown_event: mp.Event,
+ policy_actor: SACPolicy,
+ reward_classifier: Classifier,
+ env_cfg: HILSerlRobotEnvConfig,
+ device: torch.device = "mps",
+ output_directory: Path | None = None,
+):
+ """The actor process - interacts with environment and collects data.
+ The policy is frozen and only the parameters are updated, popping the most recent ones from a queue."""
+ policy_actor.eval()
+ policy_actor.to(device)
+
+ reward_classifier.eval()
+ reward_classifier.to(device)
+
+ # Create robot environment inside the actor process
+ env, teleop_device = make_robot_env(env_cfg)
+
+ try:
+ for episode in range(MAX_EPISODES):
+ if shutdown_event.is_set():
+ break
+
+ obs, _info = env.reset()
+ episode_reward = 0.0
+ step = 0
+ episode_transitions = []
+
+ print(f"[ACTOR] Starting episode {episode + 1}")
+
+ while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
+ try:
+ new_params = parameters_queue.get_nowait()
+ policy_actor.load_state_dict(new_params)
+ print("[ACTOR] Updated policy parameters from learner")
+ except Empty: # No new updated parameters available from learner, waiting
+ pass
+
+ # Get action from policy
+ policy_obs = make_policy_obs(obs, device=device)
+ action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
+ action = action_tensor.squeeze(0).cpu().numpy()
+
+ # Step environment
+ next_obs, _env_reward, terminated, truncated, _info = env.step(action)
+ done = terminated or truncated
+
+ # Predict reward
+ policy_next_obs = make_policy_obs(next_obs, device=device)
+ reward = reward_classifier.predict_reward(policy_next_obs)
+
+ if reward >= 1.0 and not done: # success detected! halt episode
+ terminated = True
+ done = True
+
+ # In HIL-SERL, human interventions come from the teleop device
+ is_intervention = False
+ if hasattr(teleop_device, "get_teleop_events"):
+ # Real intervention detection from teleop device
+ teleop_events = teleop_device.get_teleop_events()
+ is_intervention = teleop_events.get(TeleopEvents.IS_INTERVENTION, False)
+
+ # Store transition with intervention metadata
+ transition = {
+ "state": policy_obs,
+ "action": action,
+ "reward": float(reward) if hasattr(reward, "item") else reward,
+ "next_state": policy_next_obs,
+ "done": done,
+ "truncated": truncated,
+ "complementary_info": {
+ "is_intervention": is_intervention,
+ },
+ }
+
+ episode_transitions.append(transition)
+
+ episode_reward += reward
+ step += 1
+
+ obs = next_obs
+
+ if done:
+ break
+
+ # Send episode transitions to learner
+ transitions_queue.put_nowait(episode_transitions)
+
+ except KeyboardInterrupt:
+ print("[ACTOR] Interrupted by user")
+ finally:
+ # Clean up
+ if hasattr(env, "robot") and env.robot.is_connected:
+ env.robot.disconnect()
+ if teleop_device and hasattr(teleop_device, "disconnect"):
+ teleop_device.disconnect()
+ if output_directory is not None:
+ policy_actor.save_pretrained(output_directory)
+ print(f"[ACTOR] Latest actor policy saved at: {output_directory}")
+
+ print("[ACTOR] Actor process finished")
+
+
+def make_policy_obs(obs, device: torch.device = "cpu"):
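+    """Convert a gym-style observation (agent_pos + per-camera pixels) into the batched tensor dict expected by the policy."""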
+ return {
+ "observation.state": torch.from_numpy(obs["agent_pos"]).float().unsqueeze(0).to(device),
+ **{
+ f"observation.image.{k}": torch.from_numpy(obs["pixels"][k]).float().unsqueeze(0).to(device)
+ for k in obs["pixels"]
+ },
+ }
+
+
+"""Main function - coordinates actor and learner processes."""
+
+device = "mps" # or "cuda" or "cpu"
+output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
+output_directory.mkdir(parents=True, exist_ok=True)
+
+# find ports using lerobot-find-port
+follower_port = ...
+leader_port = ...
+
+# the robot ids are used to load the right calibration files
+follower_id = ...
+leader_id = ...
+
+# A pretrained reward classifier (use it in-distribution, i.e. on the same task and setup it was trained for!)
+reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
+reward_classifier = Classifier.from_pretrained(reward_classifier_id)
+
+reward_classifier.to(device)
+reward_classifier.eval()
+
+MAX_EPISODES = 5
+MAX_STEPS_PER_EPISODE = 20
+
+# Robot and environment configuration
+robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
+teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
+processor_cfg = HILSerlProcessorConfig(control_mode="leader")
+
+env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
+
+# Create robot environment
+env, teleop_device = make_robot_env(env_cfg)
+
+obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
+action_features = hw_to_dataset_features(env.robot.action_features, "action")
+
+# Create SAC policy for action selection
+policy_cfg = SACConfig(
+ device=device,
+ input_features=obs_features,
+ output_features=action_features,
+)
+
+policy_actor = SACPolicy(policy_cfg)
+policy_learner = SACPolicy(policy_cfg)
+
+demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
+offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
+
+# Online buffer: initialized from scratch
+online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
+# Offline buffer: created from the dataset (pre-populated with demonstrations)
+offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+ lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
+)
+
+# Create communication channels between learner and actor processes
+transitions_queue = mp.Queue(maxsize=10)
+parameters_queue = mp.Queue(maxsize=2)
+shutdown_event = mp.Event()
+
+
+# Signal handler for graceful shutdown
+def signal_handler(sig, frame):
+ print(f"\nSignal {sig} received, shutting down...")
+ shutdown_event.set()
+
+
+signal.signal(signal.SIGINT, signal_handler)
+signal.signal(signal.SIGTERM, signal_handler)
+
+# Create processes
+learner_process = mp.Process(
+ target=run_learner,
+ args=(
+ transitions_queue,
+ parameters_queue,
+ shutdown_event,
+ policy_learner,
+ online_replay_buffer,
+ offline_replay_buffer,
+ ),
+ kwargs={"device": device}, # can run on accelerated hardware for training
+)
+
+actor_process = mp.Process(
+ target=run_actor,
+ args=(
+ transitions_queue,
+ parameters_queue,
+ shutdown_event,
+ policy_actor,
+ reward_classifier,
+ env_cfg,
+    ),
+    kwargs={"device": "cpu", "output_directory": output_directory},  # the actor is frozen, so inference can run on CPU (or an accelerator)
+)
+
+learner_process.start()
+actor_process.start()
+
+try:
+ # Wait for actor to finish (it controls the episode loop)
+ actor_process.join()
+ shutdown_event.set()
+ learner_process.join(timeout=10)
+
+except KeyboardInterrupt:
+ print("Main process interrupted")
+ shutdown_event.set()
+ actor_process.join(timeout=5)
+ learner_process.join(timeout=10)
+
+finally:
+ if learner_process.is_alive():
+ learner_process.terminate()
+ if actor_process.is_alive():
+ actor_process.terminate()
diff --git a/examples/tutorial/rl/reward_classifier_example.py b/examples/tutorial/rl/reward_classifier_example.py
new file mode 100644
index 000000000..a3d852e30
--- /dev/null
+++ b/examples/tutorial/rl/reward_classifier_example.py
@@ -0,0 +1,62 @@
+import torch
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.policies.factory import make_policy, make_pre_post_processors
+from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
+
+# Device to use for training
+device = "mps" # or "cuda", or "cpu"
+
+# Load the dataset used for training
+repo_id = "lerobot/example_hil_serl_dataset"
+dataset = LeRobotDataset(repo_id)
+
+# Configure the policy to extract features from the image frames
+camera_keys = dataset.meta.camera_keys
+
+config = RewardClassifierConfig(
+ num_cameras=len(camera_keys),
+ device=device,
+ # backbone model to extract features from the image frames
+ model_name="microsoft/resnet-18",
+)
+
+# Make policy, preprocessor, and optimizer
+policy = make_policy(config, ds_meta=dataset.meta)
+optimizer = config.get_optimizer_preset().build(policy.parameters())
+preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
+
+
+classifier_id = "fracapuano/reward_classifier_hil_serl_example"
+
+# Instantiate a dataloader
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
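+# Each batch is assumed to pair camera frames with a binary success label, which the classifier learns to predict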
+
+# Training loop
+num_epochs = 5
+for epoch in range(num_epochs):
+ total_loss = 0
+ total_accuracy = 0
+ for batch in dataloader:
+ # Preprocess the batch and move it to the correct device.
+ batch = preprocessor(batch)
+
+ # Forward pass
+ loss, output_dict = policy.forward(batch)
+
+ # Backward pass and optimization
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ total_loss += loss.item()
+ total_accuracy += output_dict["accuracy"]
+
+ avg_loss = total_loss / len(dataloader)
+ avg_accuracy = total_accuracy / len(dataloader)
+ print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
+
+print("Training finished!")
+
+# You can now push the trained classifier to the Hub.
+policy.push_to_hub(classifier_id)
diff --git a/examples/tutorial/smolvla/using_smolvla_example.py b/examples/tutorial/smolvla/using_smolvla_example.py
new file mode 100644
index 000000000..04c327833
--- /dev/null
+++ b/examples/tutorial/smolvla/using_smolvla_example.py
@@ -0,0 +1,66 @@
+import torch
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.datasets.utils import hw_to_dataset_features
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
+from lerobot.policies.utils import build_inference_frame, make_robot_action
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+
+MAX_EPISODES = 5
+MAX_STEPS_PER_EPISODE = 20
+
+device = torch.device("mps") # or "cuda" or "cpu"
+model_id = "lerobot/smolvla_base"
+
+model = SmolVLAPolicy.from_pretrained(model_id)
+
+preprocess, postprocess = make_pre_post_processors(
+ model.config,
+ model_id,
+    # This override allows running on MPS; otherwise the device processor defaults to CUDA (if available)
+ preprocessor_overrides={"device_processor": {"device": str(device)}},
+)
+
+# find ports using lerobot-find-port
+follower_port = ... # something like "/dev/tty.usbmodem58760431631"
+
+# the robot ids are used to load the right calibration files
+follower_id = ... # something like "follower_so100"
+
+# Robot and environment configuration
+# Camera keys must match the names and resolutions of the cameras used for training!
+# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
+camera_config = {
+ "camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
+ "camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
+}
+
+robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
+robot = SO100Follower(robot_cfg)
+robot.connect()
+
+task = "" # something like "pick the red block"
+robot_type = "" # something like "so100_follower" for multi-embodiment datasets
+
+# This is used to match the raw observation keys to the keys expected by the policy
+action_features = hw_to_dataset_features(robot.action_features, "action")
+obs_features = hw_to_dataset_features(robot.observation_features, "observation")
+dataset_features = {**action_features, **obs_features}
+
+for _ in range(MAX_EPISODES):
+ for _ in range(MAX_STEPS_PER_EPISODE):
+ obs = robot.get_observation()
+ obs_frame = build_inference_frame(
+ observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
+ )
+
+ obs = preprocess(obs_frame)
+
+ action = model.select_action(obs)
+ action = postprocess(action)
+ action = make_robot_action(action, dataset_features)
+ robot.send_action(action)
+
+ print("Episode finished! Starting new episode...")
diff --git a/pyproject.toml b/pyproject.toml
index c8879ac36..0b53457a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
-version = "0.3.4"
+version = "0.4.2"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
@@ -81,8 +81,8 @@ dependencies = [
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
- "gymnasium>=1.0.0",
- "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
+ "gymnasium>=1.1.1,<2.0.0",
+ "rerun-sdk>=0.24.0,<0.27.0",
# Support dependencies
"deepdiff>=7.0.1,<9.0.0",
@@ -113,17 +113,23 @@ intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
]
-phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
-# stretch = [
-# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
-# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
-# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
-# ] # TODO: Currently not supported
+phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
-hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
+groot = [
+ "lerobot[transformers-dep]",
+ "peft>=0.13.0,<1.0.0",
+ "dm-tree>=0.1.8,<1.0.0",
+ "timm>=1.0.0,<1.1.0",
+ "safetensors>=0.4.3,<1.0.0",
+ "Pillow>=10.0.0,<13.0.0",
+ "decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
+ "ninja>=1.11.1,<2.0.0",
+ "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
+]
+hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
@@ -136,8 +142,8 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
-libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
-metaworld = ["metaworld>=3.0.0"]
+libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
+metaworld = ["metaworld==3.0.0"]
# All
all = [
@@ -150,6 +156,7 @@ all = [
"lerobot[intelrealsense]",
"lerobot[pi]",
"lerobot[smolvla]",
+ # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",
@@ -234,9 +241,6 @@ exclude_dirs = [
"tests",
"benchmarks",
"src/lerobot/datasets/push_dataset_to_hub",
- "src/lerobot/datasets/v2/convert_dataset_v1_to_v2",
- "src/lerobot/policies/pi0/conversion_scripts",
- "src/lerobot/scripts/push_dataset_to_hub.py",
]
skips = ["B101", "B311", "B404", "B603", "B615"]
@@ -251,6 +255,8 @@ default.extend-ignore-identifiers-re = [
"pn",
"ser",
"ein",
+ "thw",
+ "inpt",
]
# TODO: Uncomment when ready to use
@@ -289,7 +295,6 @@ ignore_errors = true
[[tool.mypy.overrides]]
module = "lerobot.envs.*"
-# Enable type checking only for the envs module
ignore_errors = false
@@ -297,17 +302,22 @@ ignore_errors = false
# module = "lerobot.utils.*"
# ignore_errors = false
-# [[tool.mypy.overrides]]
-# module = "lerobot.configs.*"
-# ignore_errors = false
+[[tool.mypy.overrides]]
+module = "lerobot.configs.*"
+ignore_errors = false
+
+# extra strictness for configs
+disallow_untyped_defs = true
+disallow_incomplete_defs = true
+check_untyped_defs = true
# [[tool.mypy.overrides]]
# module = "lerobot.optim.*"
# ignore_errors = false
-# [[tool.mypy.overrides]]
-# module = "lerobot.model.*"
-# ignore_errors = false
+[[tool.mypy.overrides]]
+module = "lerobot.model.*"
+ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.processor.*"
@@ -317,9 +327,9 @@ ignore_errors = false
# module = "lerobot.datasets.*"
# ignore_errors = false
-# [[tool.mypy.overrides]]
-# module = "lerobot.cameras.*"
-# ignore_errors = false
+[[tool.mypy.overrides]]
+module = "lerobot.cameras.*"
+ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
diff --git a/requirements-macos.txt b/requirements-macos.txt
index 07e263da5..dc90416a3 100644
--- a/requirements-macos.txt
+++ b/requirements-macos.txt
@@ -1,3 +1,4 @@
+#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
@@ -12,47 +13,62 @@ absl-py==2.3.1
# dm-tree
# labmaze
# mujoco
-accelerate==1.9.0
- # via lerobot
+ # tensorboard
+accelerate==1.11.0
+ # via
+ # lerobot
+ # peft
aiohappyeyeballs==2.6.1
# via aiohttp
-aiohttp==3.12.15
+aiohttp==3.13.1
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
+antlr4-python3-runtime==4.9.3
+ # via
+ # hydra-core
+ # omegaconf
+anyio==4.11.0
+ # via
+ # starlette
+ # watchfiles
asttokens==3.0.0
# via stack-data
async-timeout==5.0.1
# via aiohttp
-attrs==25.3.0
+attrs==25.4.0
# via
# aiohttp
# dm-tree
# jsonlines
+ # jsonschema
+ # referencing
# rerun-sdk
-av==15.0.0
+av==15.1.0
# via lerobot
-blinker==1.9.0
- # via flask
-certifi==2025.7.14
+bddl==1.0.1
+ # via libero
+certifi==2025.10.5
# via
# requests
# sentry-sdk
-cffi==1.17.1
+cffi==2.0.0
# via pymunk
cfgv==3.4.0
# via pre-commit
-charset-normalizer==3.4.2
+charset-normalizer==3.4.4
# via requests
-click==8.2.1
+click==8.3.0
# via
- # flask
+ # uvicorn
# wandb
cloudpickle==3.1.1
- # via gymnasium
-cmake==4.0.3
+ # via
+ # gymnasium
+ # libero
+cmake==4.1.0
# via lerobot
cmeel==0.57.3
# via
@@ -94,27 +110,27 @@ coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
-coverage[toml]==7.10.1
+coverage[toml]==7.11.0
# via pytest-cov
cycler==0.12.1
# via matplotlib
-datasets==3.6.0
+datasets==4.1.1
# via lerobot
-debugpy==1.8.15
+debugpy==1.8.17
# via lerobot
decorator==5.2.1
# via ipython
-deepdiff==8.5.0
+deepdiff==8.6.1
# via lerobot
-diffusers==0.34.0
+diffusers==0.35.2
# via lerobot
-dill==0.3.8
+dill==0.4.0
# via
# datasets
# multiprocess
distlib==0.4.0
# via virtualenv
-dm-control==1.0.14
+dm-control==1.0.34
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -122,29 +138,45 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
+ # lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
-dynamixel-sdk==3.7.31
+dynamixel-sdk==3.8.4
# via lerobot
+easydict==1.13
+ # via libero
+egl-probe @ git+https://github.com/huggingface/egl_probe.git
+ # via
+ # libero
+ # robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
- # via lerobot
+ # via
+ # lerobot
+ # libero
eiquadprog==1.2.9
# via placo
+etils[epath,epy]==1.13.0
+ # via mujoco
exceptiongroup==1.3.0
# via
+ # anyio
# ipython
# pytest
-executing==2.2.0
+executing==2.2.1
# via stack-data
farama-notifications==0.0.4
# via gymnasium
+fastapi==0.119.1
+ # via teleop
+fastjsonschema==2.21.2
+ # via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
-filelock==3.18.0
+filelock==3.20.0
# via
# datasets
# diffusers
@@ -152,24 +184,25 @@ filelock==3.18.0
# torch
# transformers
# virtualenv
-flask==3.1.1
- # via lerobot
-fonttools==4.59.0
+fonttools==4.60.1
# via matplotlib
-frozenlist==1.7.0
+frozenlist==1.8.0
# via
# aiohttp
# aiosignal
-fsspec[http]==2025.3.0
+fsspec[http]==2025.9.0
# via
# datasets
+ # etils
# huggingface-hub
# torch
+future==1.0.0
+ # via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
# via wandb
-glfw==2.9.0
+glfw==2.10.0
# via
# dm-control
# mujoco
@@ -177,61 +210,79 @@ grpcio==1.73.1
# via
# grpcio-tools
# lerobot
+ # reachy2-sdk
+ # reachy2-sdk-api
+ # tensorboard
grpcio-tools==1.73.1
+ # via
+ # lerobot
+ # reachy2-sdk-api
+gym-aloha==0.1.3
# via lerobot
-gym-aloha==0.1.1
+gym-hil==0.1.13
# via lerobot
-gym-hil==0.1.10
+gym-pusht==0.1.6
# via lerobot
-gym-pusht==0.1.5
- # via lerobot
-gym-xarm==0.1.1
- # via lerobot
-gymnasium==0.29.1
+gymnasium==1.2.1
# via
# gym-aloha
# gym-hil
# gym-pusht
- # gym-xarm
- # gymnasium-robotics
# lerobot
- # pettingzoo
-gymnasium-robotics==1.2.4
- # via gym-xarm
+ # libero
+ # metaworld
+h11==0.16.0
+ # via uvicorn
+h5py==3.15.1
+ # via robomimic
+hebi-py==2.11.0
+ # via lerobot
hf-transfer==0.1.9
# via huggingface-hub
-hf-xet==1.1.5
+hf-xet==1.1.10
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
-huggingface-hub[cli,hf-transfer]==0.34.3
+httptools==0.7.1
+ # via uvicorn
+huggingface-hub[cli,hf-transfer]==0.35.3
# via
# accelerate
# datasets
# diffusers
# lerobot
+ # peft
+ # timm
# tokenizers
# transformers
-identify==2.6.12
+hydra-core==1.3.2
+ # via libero
+identify==2.6.15
# via pre-commit
-idna==3.10
+idna==3.11
# via
+ # anyio
# requests
# yarl
imageio[ffmpeg]==2.37.0
# via
# gym-aloha
# gym-hil
- # gymnasium-robotics
# lerobot
+ # metaworld
+ # robomimic
# scikit-image
imageio-ffmpeg==0.6.0
- # via imageio
+ # via
+ # imageio
+ # robomimic
importlib-metadata==8.7.0
# via diffusers
-iniconfig==2.1.0
+importlib-resources==6.5.2
+ # via etils
+iniconfig==2.3.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
@@ -239,50 +290,71 @@ ipython==8.37.0
# via meshcat
ischedule==1.2.7
# via placo
-itsdangerous==2.2.0
- # via flask
jedi==0.19.2
# via ipython
jinja2==3.1.6
- # via
- # flask
- # gymnasium-robotics
- # torch
+ # via torch
jsonlines==4.0.0
# via lerobot
-kiwisolver==1.4.8
+jsonschema==4.25.1
+ # via nbformat
+jsonschema-specifications==2025.9.1
+ # via jsonschema
+jupyter-core==5.9.1
+ # via nbformat
+jupytext==1.18.1
+ # via bddl
+kiwisolver==1.4.9
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
# via scikit-image
-lxml==6.0.0
+libero @ git+https://github.com/huggingface/lerobot-libero.git@main
+ # via lerobot
+llvmlite==0.45.1
+ # via numba
+lxml==6.0.2
# via dm-control
-markupsafe==3.0.2
+markdown==3.9
+ # via tensorboard
+markdown-it-py==4.0.0
+ # via
+ # jupytext
+ # mdit-py-plugins
+markupsafe==3.0.3
# via
- # flask
# jinja2
# werkzeug
-matplotlib==3.10.5
- # via lerobot
-matplotlib-inline==0.1.7
+matplotlib==3.10.7
+ # via
+ # lerobot
+ # libero
+matplotlib-inline==0.2.1
# via ipython
+mdit-py-plugins==0.5.0
+ # via jupytext
+mdurl==0.1.2
+ # via markdown-it-py
mergedeep==1.3.4
# via draccus
meshcat==0.3.2
# via placo
+metaworld==3.0.0
+ # via lerobot
mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
-mujoco==2.3.7
+mujoco==3.3.7
# via
# dm-control
# gym-aloha
# gym-hil
- # gym-xarm
- # gymnasium-robotics
-multidict==6.6.3
+ # libero
+ # metaworld
+ # robosuite
+multidict==6.7.0
# via
# aiohttp
# yarl
@@ -290,17 +362,25 @@ multiprocess==0.70.16
# via datasets
mypy-extensions==1.1.0
# via typing-inspect
+nbformat==5.10.4
+ # via jupytext
networkx==3.4.2
# via
+ # bddl
# scikit-image
# torch
+ninja==1.13.0
+ # via lerobot
nodeenv==1.9.1
# via pre-commit
num2words==0.5.14
# via lerobot
+numba==0.62.1
+ # via robosuite
numpy==2.2.6
# via
# accelerate
+ # bddl
# cmeel-boost
# contourpy
# datasets
@@ -309,25 +389,43 @@ numpy==2.2.6
# dm-env
# dm-tree
# gymnasium
- # gymnasium-robotics
+ # h5py
+ # hebi-py
# imageio
# labmaze
+ # libero
# matplotlib
# meshcat
+ # metaworld
# mujoco
+ # numba
# opencv-python
# opencv-python-headless
# pandas
- # pettingzoo
+ # peft
+ # pyquaternion
+ # reachy2-sdk
# rerun-sdk
+ # robomimic
+ # robosuite
# scikit-image
# scipy
# shapely
+ # teleop
+ # tensorboard
+ # tensorboardx
# tifffile
# torchvision
# transformers
+ # transforms3d
+omegaconf==2.3.0
+ # via hydra-core
opencv-python==4.12.0.88
- # via gym-pusht
+ # via
+ # gym-pusht
+ # libero
+ # reachy2-sdk
+ # robosuite
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -337,53 +435,63 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
+ # hydra-core
+ # jupytext
# lazy-loader
# lerobot
# matplotlib
+ # peft
# pytest
+ # reachy2-sdk
# scikit-image
+ # tensorboard
+ # tensorboardx
# transformers
# wandb
-pandas==2.3.1
+pandas==2.3.3
# via
# datasets
# lerobot
-parso==0.8.4
+parso==0.8.5
# via jedi
-pettingzoo==1.24.3
- # via gymnasium-robotics
+peft==0.17.1
+ # via lerobot
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
-pillow==11.3.0
+pillow==12.0.0
# via
# diffusers
# imageio
+ # lerobot
# matplotlib
# meshcat
# rerun-sdk
+ # robosuite
# scikit-image
+ # tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
# via lerobot
-platformdirs==4.3.8
+platformdirs==4.5.0
# via
+ # jupyter-core
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
-pre-commit==4.2.0
+pre-commit==4.3.0
# via lerobot
-prompt-toolkit==3.0.51
+prompt-toolkit==3.0.52
# via
# inquirerpy
# ipython
-propcache==0.3.2
+propcache==0.4.1
# via
# aiohttp
# yarl
@@ -392,11 +500,17 @@ protobuf==6.31.0
# dm-control
# grpcio-tools
# lerobot
+ # reachy2-sdk
+ # reachy2-sdk-api
+ # tensorboard
+ # tensorboardx
# wandb
-psutil==7.0.0
+psutil==7.1.1
# via
# accelerate
# imageio
+ # peft
+ # robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
@@ -405,11 +519,13 @@ pyarrow==21.0.0
# via
# datasets
# rerun-sdk
-pycparser==2.22
+pycparser==2.23
# via cffi
-pydantic==2.11.7
- # via wandb
-pydantic-core==2.33.2
+pydantic==2.12.3
+ # via
+ # fastapi
+ # wandb
+pydantic-core==2.41.4
# via pydantic
pygame==2.6.1
# via
@@ -424,40 +540,42 @@ pymunk==6.11.1
# via
# gym-pusht
# lerobot
-pyngrok==7.2.12
+pyngrok==7.4.1
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
-pyobjc-core==11.1
+pyobjc-core==12.0
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-cocoa
# pyobjc-framework-coretext
# pyobjc-framework-quartz
-pyobjc-framework-applicationservices==11.1
+pyobjc-framework-applicationservices==12.0
# via pynput
-pyobjc-framework-cocoa==11.1
+pyobjc-framework-cocoa==12.0
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
# pyobjc-framework-quartz
-pyobjc-framework-coretext==11.1
+pyobjc-framework-coretext==12.0
# via pyobjc-framework-applicationservices
-pyobjc-framework-quartz==11.1
+pyobjc-framework-quartz==12.0
# via
# pynput
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
-pyopengl==3.1.9
+pyopengl==3.1.10
# via
# dm-control
# mujoco
-pyparsing==3.2.3
+pyparsing==3.2.5
# via
# dm-control
# matplotlib
+pyquaternion==0.9.9
+ # via reachy2-sdk
pyrealsense2-macosx==2.54.2
# via lerobot
pyserial==3.5
@@ -465,12 +583,14 @@ pyserial==3.5
# dynamixel-sdk
# feetech-servo-sdk
# lerobot
-pytest==8.4.1
+pytest==8.4.2
# via
+ # bddl
# lerobot
# pytest-cov
# pytest-timeout
-pytest-cov==6.2.1
+ # teleop
+pytest-cov==7.0.0
# via lerobot
pytest-timeout==2.4.0
# via lerobot
@@ -478,46 +598,73 @@ python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
+python-dotenv==1.1.1
+ # via uvicorn
pytz==2025.2
# via pandas
-pyyaml==6.0.2
+pyyaml==6.0.3
# via
# accelerate
# datasets
# draccus
+ # hebi-py
# huggingface-hub
+ # jupytext
+ # omegaconf
+ # peft
# pre-commit
# pyngrok
# pyyaml-include
+ # timm
# transformers
+ # uvicorn
# wandb
pyyaml-include==1.4.1
# via draccus
-pyzmq==27.0.0
+pyzmq==27.1.0
# via
# lerobot
# meshcat
-regex==2025.7.34
+reachy2-sdk==1.0.14
+ # via lerobot
+reachy2-sdk-api==1.0.21
+ # via reachy2-sdk
+referencing==0.37.0
+ # via
+ # jsonschema
+ # jsonschema-specifications
+regex==2025.10.23
# via
# diffusers
# transformers
-requests==2.32.4
+requests==2.32.5
# via
# datasets
# diffusers
# dm-control
# huggingface-hub
+ # teleop
# transformers
# wandb
-rerun-sdk==0.22.1
+rerun-sdk==0.26.1
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
-safetensors==0.5.3
+robomimic==0.2.0
+ # via libero
+robosuite==1.4.0
+ # via libero
+rpds-py==0.28.0
+ # via
+ # jsonschema
+ # referencing
+safetensors==0.6.2
# via
# accelerate
# diffusers
# lerobot
+ # peft
+ # timm
# transformers
scikit-image==0.25.2
# via
@@ -526,10 +673,12 @@ scikit-image==0.25.2
scipy==1.15.3
# via
# dm-control
+ # metaworld
+ # robosuite
# scikit-image
-sentry-sdk==2.34.1
+sentry-sdk==2.42.1
# via wandb
-shapely==2.1.1
+shapely==2.1.2
# via gym-pusht
six==1.17.0
# via
@@ -537,64 +686,106 @@ six==1.17.0
# python-dateutil
smmap==5.0.2
# via gitdb
+sniffio==1.3.1
+ # via anyio
stack-data==0.6.3
# via ipython
+starlette==0.48.0
+ # via fastapi
sympy==1.14.0
# via torch
-termcolor==3.1.0
+teleop==0.1.2
# via lerobot
+tensorboard==2.20.0
+ # via robomimic
+tensorboard-data-server==0.7.2
+ # via tensorboard
+tensorboardx==2.6.4
+ # via robomimic
+termcolor==3.1.0
+ # via
+ # lerobot
+ # robomimic
+thop==0.1.1.post2209072238
+ # via libero
tifffile==2025.5.10
# via scikit-image
-tokenizers==0.21.4
+timm==1.0.20
+ # via lerobot
+tokenizers==0.22.1
# via transformers
toml==0.10.2
# via draccus
-tomli==2.2.1
+tomli==2.3.0
# via
# cmeel
# coverage
+ # jupytext
# pytest
torch==2.7.1
# via
# accelerate
# lerobot
+ # peft
+ # robomimic
+ # thop
+ # timm
# torchvision
torchcodec==0.5
# via lerobot
torchvision==0.22.1
- # via lerobot
-tornado==6.5.1
+ # via
+ # lerobot
+ # robomimic
+ # timm
+tornado==6.5.2
# via meshcat
tqdm==4.67.1
# via
# datasets
# dm-control
# huggingface-hub
+ # peft
+ # robomimic
# transformers
traitlets==5.14.3
# via
# ipython
+ # jupyter-core
# matplotlib-inline
-transformers==4.51.3
- # via lerobot
-typing-extensions==4.14.1
+ # nbformat
+transformers==4.57.1
+ # via
+ # lerobot
+ # libero
+ # peft
+transforms3d==0.4.2
+ # via teleop
+typing-extensions==4.15.0
# via
# aiosignal
+ # anyio
+ # etils
# exceptiongroup
+ # fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# pydantic
# pydantic-core
+ # referencing
# rerun-sdk
+ # starlette
# torch
# typing-inspect
# typing-inspection
+ # uvicorn
+ # virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
-typing-inspection==0.4.1
+typing-inspection==0.4.2
# via pydantic
tzdata==2025.2
# via pandas
@@ -604,22 +795,36 @@ urllib3==2.5.0
# via
# requests
# sentry-sdk
-virtualenv==20.32.0
+uvicorn[standard]==0.38.0
+ # via teleop
+uvloop==0.22.1
+ # via uvicorn
+virtualenv==20.35.3
# via pre-commit
-wandb==0.21.0
- # via lerobot
-wcwidth==0.2.13
+wandb==0.21.4
+ # via
+ # lerobot
+ # libero
+watchfiles==1.1.1
+ # via uvicorn
+wcwidth==0.2.14
# via prompt-toolkit
+websocket-client==1.9.0
+ # via teleop
+websockets==15.0.1
+ # via uvicorn
werkzeug==3.1.3
- # via flask
-wrapt==1.17.2
+ # via tensorboard
+wrapt==2.0.0
# via dm-tree
-xxhash==3.5.0
+xxhash==3.6.0
# via datasets
-yarl==1.20.1
+yarl==1.22.0
# via aiohttp
zipp==3.23.0
- # via importlib-metadata
+ # via
+ # etils
+ # importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/requirements-ubuntu.txt b/requirements-ubuntu.txt
index af7258d67..8413feac3 100644
--- a/requirements-ubuntu.txt
+++ b/requirements-ubuntu.txt
@@ -13,47 +13,62 @@ absl-py==2.3.1
# dm-tree
# labmaze
# mujoco
-accelerate==1.9.0
- # via lerobot
+ # tensorboard
+accelerate==1.11.0
+ # via
+ # lerobot
+ # peft
aiohappyeyeballs==2.6.1
# via aiohttp
-aiohttp==3.12.15
+aiohttp==3.13.1
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
+antlr4-python3-runtime==4.9.3
+ # via
+ # hydra-core
+ # omegaconf
+anyio==4.11.0
+ # via
+ # starlette
+ # watchfiles
asttokens==3.0.0
# via stack-data
async-timeout==5.0.1
# via aiohttp
-attrs==25.3.0
+attrs==25.4.0
# via
# aiohttp
# dm-tree
# jsonlines
+ # jsonschema
+ # referencing
# rerun-sdk
-av==15.0.0
+av==15.1.0
# via lerobot
-blinker==1.9.0
- # via flask
-certifi==2025.7.14
+bddl==1.0.1
+ # via libero
+certifi==2025.10.5
# via
# requests
# sentry-sdk
-cffi==1.17.1
+cffi==2.0.0
# via pymunk
cfgv==3.4.0
# via pre-commit
-charset-normalizer==3.4.2
+charset-normalizer==3.4.4
# via requests
-click==8.2.1
+click==8.3.0
# via
- # flask
+ # uvicorn
# wandb
cloudpickle==3.1.1
- # via gymnasium
-cmake==4.0.3
+ # via
+ # gymnasium
+ # libero
+cmake==4.1.0
# via lerobot
cmeel==0.57.3
# via
@@ -95,27 +110,29 @@ coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
-coverage[toml]==7.10.1
+coverage[toml]==7.11.0
# via pytest-cov
cycler==0.12.1
# via matplotlib
-datasets==3.6.0
+datasets==4.1.1
# via lerobot
-debugpy==1.8.15
+debugpy==1.8.17
# via lerobot
decorator==5.2.1
# via ipython
-deepdiff==8.5.0
+decord==0.6.0
# via lerobot
-diffusers==0.34.0
+deepdiff==8.6.1
# via lerobot
-dill==0.3.8
+diffusers==0.35.2
+ # via lerobot
+dill==0.4.0
# via
# datasets
# multiprocess
distlib==0.4.0
# via virtualenv
-dm-control==1.0.14
+dm-control==1.0.34
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -123,31 +140,48 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
+ # lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
-dynamixel-sdk==3.7.31
+dynamixel-sdk==3.8.4
# via lerobot
+easydict==1.13
+ # via libero
+egl-probe @ git+https://github.com/huggingface/egl_probe.git
+ # via
+ # libero
+ # robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
- # via lerobot
+ # via
+ # flash-attn
+ # lerobot
+ # libero
eiquadprog==1.2.9
# via placo
+etils[epath,epy]==1.13.0
+ # via mujoco
evdev==1.9.2
# via pynput
exceptiongroup==1.3.0
# via
+ # anyio
# ipython
# pytest
-executing==2.2.0
+executing==2.2.1
# via stack-data
farama-notifications==0.0.4
# via gymnasium
+fastapi==0.119.1
+ # via teleop
+fastjsonschema==2.21.2
+ # via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
-filelock==3.18.0
+filelock==3.20.0
# via
# datasets
# diffusers
@@ -155,24 +189,27 @@ filelock==3.18.0
# torch
# transformers
# virtualenv
-flask==3.1.1
+flash-attn==2.8.3
# via lerobot
-fonttools==4.59.0
+fonttools==4.60.1
# via matplotlib
-frozenlist==1.7.0
+frozenlist==1.8.0
# via
# aiohttp
# aiosignal
-fsspec[http]==2025.3.0
+fsspec[http]==2025.9.0
# via
# datasets
+ # etils
# huggingface-hub
# torch
+future==1.0.0
+ # via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
# via wandb
-glfw==2.9.0
+glfw==2.10.0
# via
# dm-control
# mujoco
@@ -180,61 +217,79 @@ grpcio==1.73.1
# via
# grpcio-tools
# lerobot
+ # reachy2-sdk
+ # reachy2-sdk-api
+ # tensorboard
grpcio-tools==1.73.1
+ # via
+ # lerobot
+ # reachy2-sdk-api
+gym-aloha==0.1.3
# via lerobot
-gym-aloha==0.1.1
+gym-hil==0.1.13
# via lerobot
-gym-hil==0.1.10
+gym-pusht==0.1.6
# via lerobot
-gym-pusht==0.1.5
- # via lerobot
-gym-xarm==0.1.1
- # via lerobot
-gymnasium==0.29.1
+gymnasium==1.2.1
# via
# gym-aloha
# gym-hil
# gym-pusht
- # gym-xarm
- # gymnasium-robotics
# lerobot
- # pettingzoo
-gymnasium-robotics==1.2.4
- # via gym-xarm
+ # libero
+ # metaworld
+h11==0.16.0
+ # via uvicorn
+h5py==3.15.1
+ # via robomimic
+hebi-py==2.11.0
+ # via lerobot
hf-transfer==0.1.9
# via huggingface-hub
-hf-xet==1.1.5
+hf-xet==1.1.10
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
-huggingface-hub[cli,hf-transfer]==0.34.3
+httptools==0.7.1
+ # via uvicorn
+huggingface-hub[cli,hf-transfer]==0.35.3
# via
# accelerate
# datasets
# diffusers
# lerobot
+ # peft
+ # timm
# tokenizers
# transformers
-identify==2.6.12
+hydra-core==1.3.2
+ # via libero
+identify==2.6.15
# via pre-commit
-idna==3.10
+idna==3.11
# via
+ # anyio
# requests
# yarl
imageio[ffmpeg]==2.37.0
# via
# gym-aloha
# gym-hil
- # gymnasium-robotics
# lerobot
+ # metaworld
+ # robomimic
# scikit-image
imageio-ffmpeg==0.6.0
- # via imageio
+ # via
+ # imageio
+ # robomimic
importlib-metadata==8.7.0
# via diffusers
-iniconfig==2.1.0
+importlib-resources==6.5.2
+ # via etils
+iniconfig==2.3.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
@@ -242,50 +297,71 @@ ipython==8.37.0
# via meshcat
ischedule==1.2.7
# via placo
-itsdangerous==2.2.0
- # via flask
jedi==0.19.2
# via ipython
jinja2==3.1.6
- # via
- # flask
- # gymnasium-robotics
- # torch
+ # via torch
jsonlines==4.0.0
# via lerobot
-kiwisolver==1.4.8
+jsonschema==4.25.1
+ # via nbformat
+jsonschema-specifications==2025.9.1
+ # via jsonschema
+jupyter-core==5.9.1
+ # via nbformat
+jupytext==1.18.1
+ # via bddl
+kiwisolver==1.4.9
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
# via scikit-image
-lxml==6.0.0
+libero @ git+https://github.com/huggingface/lerobot-libero.git@main
+ # via lerobot
+llvmlite==0.45.1
+ # via numba
+lxml==6.0.2
# via dm-control
-markupsafe==3.0.2
+markdown==3.9
+ # via tensorboard
+markdown-it-py==4.0.0
+ # via
+ # jupytext
+ # mdit-py-plugins
+markupsafe==3.0.3
# via
- # flask
# jinja2
# werkzeug
-matplotlib==3.10.5
- # via lerobot
-matplotlib-inline==0.1.7
+matplotlib==3.10.7
+ # via
+ # lerobot
+ # libero
+matplotlib-inline==0.2.1
# via ipython
+mdit-py-plugins==0.5.0
+ # via jupytext
+mdurl==0.1.2
+ # via markdown-it-py
mergedeep==1.3.4
# via draccus
meshcat==0.3.2
# via placo
+metaworld==3.0.0
+ # via lerobot
mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
-mujoco==2.3.7
+mujoco==3.3.7
# via
# dm-control
# gym-aloha
# gym-hil
- # gym-xarm
- # gymnasium-robotics
-multidict==6.6.3
+ # libero
+ # metaworld
+ # robosuite
+multidict==6.7.0
# via
# aiohttp
# yarl
@@ -293,42 +369,63 @@ multiprocess==0.70.16
# via datasets
mypy-extensions==1.1.0
# via typing-inspect
+nbformat==5.10.4
+ # via jupytext
networkx==3.4.2
# via
+ # bddl
# scikit-image
# torch
+ninja==1.13.0
+ # via lerobot
nodeenv==1.9.1
# via pre-commit
num2words==0.5.14
# via lerobot
+numba==0.62.1
+ # via robosuite
numpy==2.2.6
# via
# accelerate
+ # bddl
# cmeel-boost
# contourpy
# datasets
+ # decord
# diffusers
# dm-control
# dm-env
# dm-tree
# gymnasium
- # gymnasium-robotics
+ # h5py
+ # hebi-py
# imageio
# labmaze
+ # libero
# matplotlib
# meshcat
+ # metaworld
# mujoco
+ # numba
# opencv-python
# opencv-python-headless
# pandas
- # pettingzoo
+ # peft
+ # pyquaternion
+ # reachy2-sdk
# rerun-sdk
+ # robomimic
+ # robosuite
# scikit-image
# scipy
# shapely
+ # teleop
+ # tensorboard
+ # tensorboardx
# tifffile
# torchvision
# transformers
+ # transforms3d
nvidia-cublas-cu12==12.6.4.1
# via
# nvidia-cudnn-cu12
@@ -366,8 +463,14 @@ nvidia-nvjitlink-cu12==12.6.85
# torch
nvidia-nvtx-cu12==12.6.77
# via torch
+omegaconf==2.3.0
+ # via hydra-core
opencv-python==4.12.0.88
- # via gym-pusht
+ # via
+ # gym-pusht
+ # libero
+ # reachy2-sdk
+ # robosuite
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -377,53 +480,63 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
+ # hydra-core
+ # jupytext
# lazy-loader
# lerobot
# matplotlib
+ # peft
# pytest
+ # reachy2-sdk
# scikit-image
+ # tensorboard
+ # tensorboardx
# transformers
# wandb
-pandas==2.3.1
+pandas==2.3.3
# via
# datasets
# lerobot
-parso==0.8.4
+parso==0.8.5
# via jedi
-pettingzoo==1.24.3
- # via gymnasium-robotics
+peft==0.17.1
+ # via lerobot
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
-pillow==11.3.0
+pillow==12.0.0
# via
# diffusers
# imageio
+ # lerobot
# matplotlib
# meshcat
# rerun-sdk
+ # robosuite
# scikit-image
+ # tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
# via lerobot
-platformdirs==4.3.8
+platformdirs==4.5.0
# via
+ # jupyter-core
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
-pre-commit==4.2.0
+pre-commit==4.3.0
# via lerobot
-prompt-toolkit==3.0.51
+prompt-toolkit==3.0.52
# via
# inquirerpy
# ipython
-propcache==0.3.2
+propcache==0.4.1
# via
# aiohttp
# yarl
@@ -432,11 +545,17 @@ protobuf==6.31.0
# dm-control
# grpcio-tools
# lerobot
+ # reachy2-sdk
+ # reachy2-sdk-api
+ # tensorboard
+ # tensorboardx
# wandb
-psutil==7.0.0
+psutil==7.1.1
# via
# accelerate
# imageio
+ # peft
+ # robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
@@ -445,11 +564,13 @@ pyarrow==21.0.0
# via
# datasets
# rerun-sdk
-pycparser==2.22
+pycparser==2.23
# via cffi
-pydantic==2.11.7
- # via wandb
-pydantic-core==2.33.2
+pydantic==2.12.3
+ # via
+ # fastapi
+ # wandb
+pydantic-core==2.41.4
# via pydantic
pygame==2.6.1
# via
@@ -464,20 +585,22 @@ pymunk==6.11.1
# via
# gym-pusht
# lerobot
-pyngrok==7.2.12
+pyngrok==7.4.1
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
-pyopengl==3.1.9
+pyopengl==3.1.10
# via
# dm-control
# mujoco
-pyparsing==3.2.3
+pyparsing==3.2.5
# via
# dm-control
# matplotlib
+pyquaternion==0.9.9
+ # via reachy2-sdk
pyrealsense2==2.56.5.9235
# via lerobot
pyserial==3.5
@@ -485,12 +608,14 @@ pyserial==3.5
# dynamixel-sdk
# feetech-servo-sdk
# lerobot
-pytest==8.4.1
+pytest==8.4.2
# via
+ # bddl
# lerobot
# pytest-cov
# pytest-timeout
-pytest-cov==6.2.1
+ # teleop
+pytest-cov==7.0.0
# via lerobot
pytest-timeout==2.4.0
# via lerobot
@@ -498,48 +623,75 @@ python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
+python-dotenv==1.1.1
+ # via uvicorn
python-xlib==0.33
# via pynput
pytz==2025.2
# via pandas
-pyyaml==6.0.2
+pyyaml==6.0.3
# via
# accelerate
# datasets
# draccus
+ # hebi-py
# huggingface-hub
+ # jupytext
+ # omegaconf
+ # peft
# pre-commit
# pyngrok
# pyyaml-include
+ # timm
# transformers
+ # uvicorn
# wandb
pyyaml-include==1.4.1
# via draccus
-pyzmq==27.0.0
+pyzmq==27.1.0
# via
# lerobot
# meshcat
-regex==2025.7.34
+reachy2-sdk==1.0.14
+ # via lerobot
+reachy2-sdk-api==1.0.21
+ # via reachy2-sdk
+referencing==0.37.0
+ # via
+ # jsonschema
+ # jsonschema-specifications
+regex==2025.10.23
# via
# diffusers
# transformers
-requests==2.32.4
+requests==2.32.5
# via
# datasets
# diffusers
# dm-control
# huggingface-hub
+ # teleop
# transformers
# wandb
-rerun-sdk==0.22.1
+rerun-sdk==0.26.1
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
-safetensors==0.5.3
+robomimic==0.2.0
+ # via libero
+robosuite==1.4.0
+ # via libero
+rpds-py==0.28.0
+ # via
+ # jsonschema
+ # referencing
+safetensors==0.6.2
# via
# accelerate
# diffusers
# lerobot
+ # peft
+ # timm
# transformers
scikit-image==0.25.2
# via
@@ -548,10 +700,12 @@ scikit-image==0.25.2
scipy==1.15.3
# via
# dm-control
+ # metaworld
+ # robosuite
# scikit-image
-sentry-sdk==2.34.1
+sentry-sdk==2.42.1
# via wandb
-shapely==2.1.1
+shapely==2.1.2
# via gym-pusht
six==1.17.0
# via
@@ -560,66 +714,109 @@ six==1.17.0
# python-xlib
smmap==5.0.2
# via gitdb
+sniffio==1.3.1
+ # via anyio
stack-data==0.6.3
# via ipython
+starlette==0.48.0
+ # via fastapi
sympy==1.14.0
# via torch
-termcolor==3.1.0
+teleop==0.1.2
# via lerobot
+tensorboard==2.20.0
+ # via robomimic
+tensorboard-data-server==0.7.2
+ # via tensorboard
+tensorboardx==2.6.4
+ # via robomimic
+termcolor==3.1.0
+ # via
+ # lerobot
+ # robomimic
+thop==0.1.1.post2209072238
+ # via libero
tifffile==2025.5.10
# via scikit-image
-tokenizers==0.21.4
+timm==1.0.20
+ # via lerobot
+tokenizers==0.22.1
# via transformers
toml==0.10.2
# via draccus
-tomli==2.2.1
+tomli==2.3.0
# via
# cmeel
# coverage
+ # jupytext
# pytest
torch==2.7.1
# via
# accelerate
+ # flash-attn
# lerobot
+ # peft
+ # robomimic
+ # thop
+ # timm
# torchvision
torchcodec==0.5
# via lerobot
torchvision==0.22.1
- # via lerobot
-tornado==6.5.1
+ # via
+ # lerobot
+ # robomimic
+ # timm
+tornado==6.5.2
# via meshcat
tqdm==4.67.1
# via
# datasets
# dm-control
# huggingface-hub
+ # peft
+ # robomimic
# transformers
traitlets==5.14.3
# via
# ipython
+ # jupyter-core
# matplotlib-inline
-transformers==4.51.3
- # via lerobot
+ # nbformat
+transformers==4.57.1
+ # via
+ # lerobot
+ # libero
+ # peft
+transforms3d==0.4.2
+ # via teleop
triton==3.3.1
# via torch
-typing-extensions==4.14.1
+typing-extensions==4.15.0
# via
# aiosignal
+ # anyio
+ # etils
# exceptiongroup
+ # fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# pydantic
# pydantic-core
+ # referencing
# rerun-sdk
+ # starlette
# torch
# typing-inspect
# typing-inspection
+ # uvicorn
+ # virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
-typing-inspection==0.4.1
+typing-inspection==0.4.2
# via pydantic
tzdata==2025.2
# via pandas
@@ -629,22 +826,36 @@ urllib3==2.5.0
# via
# requests
# sentry-sdk
-virtualenv==20.32.0
+uvicorn[standard]==0.38.0
+ # via teleop
+uvloop==0.22.1
+ # via uvicorn
+virtualenv==20.35.3
# via pre-commit
-wandb==0.21.0
- # via lerobot
-wcwidth==0.2.13
+wandb==0.21.4
+ # via
+ # lerobot
+ # libero
+watchfiles==1.1.1
+ # via uvicorn
+wcwidth==0.2.14
# via prompt-toolkit
+websocket-client==1.9.0
+ # via teleop
+websockets==15.0.1
+ # via uvicorn
werkzeug==3.1.3
- # via flask
-wrapt==1.17.2
+ # via tensorboard
+wrapt==2.0.0
# via dm-tree
-xxhash==3.5.0
+xxhash==3.6.0
# via datasets
-yarl==1.20.1
+yarl==1.22.0
# via aiohttp
zipp==3.23.0
- # via importlib-metadata
+ # via
+ # etils
+ # importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/requirements.in b/requirements.in
index 272f7f540..df2a07d67 100644
--- a/requirements.in
+++ b/requirements.in
@@ -1,9 +1,9 @@
# requirements.in
-# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
-# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
+# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
+# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
-# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
-# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
+# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
+# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
-e .[all]
diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py
index f73cbc1da..2158f51ac 100644
--- a/src/lerobot/async_inference/helpers.py
+++ b/src/lerobot/async_inference/helpers.py
@@ -16,7 +16,7 @@ import logging
import logging.handlers
import os
import time
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from pathlib import Path
import torch
@@ -268,6 +268,7 @@ class RemotePolicyConfig:
lerobot_features: dict[str, PolicyFeature]
actions_per_chunk: int
device: str = "cpu"
+ rename_map: dict[str, str] = field(default_factory=dict)
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
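Note: a self-contained sketch (not the real class) of how the new `rename_map` field with `default_factory=dict` behaves: each config instance gets its own empty mapping unless one is passed explicitly. The observation keys below are purely illustrative.

from dataclasses import dataclass, field

@dataclass
class _DemoConfig:
    device: str = "cpu"
    rename_map: dict[str, str] = field(default_factory=dict)

default_cfg = _DemoConfig()  # rename_map == {}
custom_cfg = _DemoConfig(
    device="cuda",
    rename_map={"observation.images.wrist_cam": "observation.images.top"},  # illustrative keys
)
assert default_cfg.rename_map == {} and default_cfg.rename_map is not custom_cfg.rename_map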
diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py
index f7e00dea4..ab2e6bcd8 100644
--- a/src/lerobot/async_inference/policy_server.py
+++ b/src/lerobot/async_inference/policy_server.py
@@ -159,7 +159,10 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
self.preprocessor, self.postprocessor = make_pre_post_processors(
self.policy.config,
pretrained_path=policy_specs.pretrained_name_or_path,
- preprocessor_overrides={"device_processor": device_override},
+ preprocessor_overrides={
+ "device_processor": device_override,
+ "rename_observations_processor": {"rename_map": policy_specs.rename_map},
+ },
postprocessor_overrides={"device_processor": device_override},
)
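Note: a hedged sketch of the override mapping assembled above. Each key names a processor step and the nested dict supplies constructor overrides for that step; the step names come from the surrounding diff, while the device and rename values are illustrative.

device_override = {"device": "cpu"}
rename_map = {"observation.images.wrist_cam": "observation.images.top"}

preprocessor_overrides = {
    "device_processor": device_override,
    "rename_observations_processor": {"rename_map": rename_map},
}
postprocessor_overrides = {"device_processor": device_override}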
diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py
index e435c7309..bfdb571a7 100644
--- a/src/lerobot/cameras/camera.py
+++ b/src/lerobot/cameras/camera.py
@@ -17,7 +17,7 @@
import abc
from typing import Any
-import numpy as np
+from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
from .configs import CameraConfig, ColorMode
@@ -89,7 +89,7 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
- def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
+ def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""Capture and return a single frame from the camera.
Args:
@@ -102,7 +102,7 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
- def async_read(self, timeout_ms: float = ...) -> np.ndarray:
+ def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
"""Asynchronously capture and return a single frame from the camera.
Args:
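Note: a minimal sketch of how a concrete camera could satisfy the `NDArray[Any]`-typed read interface introduced above. `DummyCamera` and its constant frame are hypothetical; it does not subclass the real `Camera` ABC.

from typing import Any

import numpy as np
from numpy.typing import NDArray

class DummyCamera:
    def read(self, color_mode: str | None = None) -> NDArray[Any]:
        # Return a black 480x640 RGB frame; a real camera would grab from hardware.
        return np.zeros((480, 640, 3), dtype=np.uint8)

frame = DummyCamera().read()
assert frame.shape == (480, 640, 3)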
diff --git a/src/lerobot/cameras/configs.py b/src/lerobot/cameras/configs.py
index 0488a97ff..056eec314 100644
--- a/src/lerobot/cameras/configs.py
+++ b/src/lerobot/cameras/configs.py
@@ -18,7 +18,7 @@ import abc
from dataclasses import dataclass
from enum import Enum
-import draccus
+import draccus # type: ignore # TODO: add type stubs for draccus
class ColorMode(str, Enum):
@@ -34,11 +34,11 @@ class Cv2Rotation(int, Enum):
@dataclass(kw_only=True)
-class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
+class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
fps: int | None = None
width: int | None = None
height: int | None = None
@property
def type(self) -> str:
- return self.get_choice_name(self.__class__)
+ return str(self.get_choice_name(self.__class__))
diff --git a/src/lerobot/cameras/opencv/__init__.py b/src/lerobot/cameras/opencv/__init__.py
index 11d3139fe..377dafc05 100644
--- a/src/lerobot/cameras/opencv/__init__.py
+++ b/src/lerobot/cameras/opencv/__init__.py
@@ -14,3 +14,5 @@
from .camera_opencv import OpenCVCamera
from .configuration_opencv import OpenCVCameraConfig
+
+__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]
diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py
index 50e55f0c2..b1043ba64 100644
--- a/src/lerobot/cameras/opencv/camera_opencv.py
+++ b/src/lerobot/cameras/opencv/camera_opencv.py
@@ -25,11 +25,12 @@ from pathlib import Path
from threading import Event, Lock, Thread
from typing import Any
+from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
+
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
-import cv2
-import numpy as np
+import cv2 # type: ignore # TODO: add type stubs for OpenCV
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@@ -121,7 +122,7 @@ class OpenCVCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
- self.latest_frame: np.ndarray | None = None
+ self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -140,7 +141,7 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
- def connect(self, warmup: bool = True):
+ def connect(self, warmup: bool = True) -> None:
"""
Connects to the OpenCV camera specified in the configuration.
@@ -180,12 +181,14 @@ class OpenCVCamera(Camera):
def _configure_capture_settings(self) -> None:
"""
- Applies the specified FPS, width, and height settings to the connected camera.
+ Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
This method attempts to set the camera properties via OpenCV. It checks if
the camera successfully applied the settings and raises an error if not.
+ FOURCC is set first (if specified) as it can affect the available FPS and resolution options.
Args:
+        fourcc: The desired FOURCC code (e.g., "MJPG", "YUYV"). If None, the setting is skipped.
fps: The desired frames per second. If None, the setting is skipped.
width: The desired capture width. If None, the setting is skipped.
height: The desired capture height. If None, the setting is skipped.
@@ -199,10 +202,11 @@ class OpenCVCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
- if self.fps is None:
- self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
- else:
- self._validate_fps()
+ # Set FOURCC first (if specified) as it can affect available FPS/resolution options
+ if self.config.fourcc is not None:
+ self._validate_fourcc()
+ if self.videocapture is None:
+ raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
@@ -216,18 +220,56 @@ class OpenCVCamera(Camera):
else:
self._validate_width_and_height()
+ if self.fps is None:
+ self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
+ else:
+ self._validate_fps()
+
def _validate_fps(self) -> None:
"""Validates and sets the camera's frames per second (FPS)."""
+ if self.videocapture is None:
+ raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
+
+ if self.fps is None:
+ raise ValueError(f"{self} FPS is not set")
+
success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps))
actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
# Use math.isclose for robust float comparison
if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).")
+ def _validate_fourcc(self) -> None:
+ """Validates and sets the camera's FOURCC code."""
+
+ fourcc_code = cv2.VideoWriter_fourcc(*self.config.fourcc)
+
+ if self.videocapture is None:
+ raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
+
+ success = self.videocapture.set(cv2.CAP_PROP_FOURCC, fourcc_code)
+ actual_fourcc_code = self.videocapture.get(cv2.CAP_PROP_FOURCC)
+
+ # Convert actual FOURCC code back to string for comparison
+ actual_fourcc_code_int = int(actual_fourcc_code)
+ actual_fourcc = "".join([chr((actual_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
+
+ if not success or actual_fourcc != self.config.fourcc:
+ logger.warning(
+ f"{self} failed to set fourcc={self.config.fourcc} (actual={actual_fourcc}, success={success}). "
+ f"Continuing with default format."
+ )
+
def _validate_width_and_height(self) -> None:
"""Validates and sets the camera's frame capture width and height."""
+ if self.videocapture is None:
+ raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
+
+ if self.capture_width is None or self.capture_height is None:
+ raise ValueError(f"{self} capture_width or capture_height is not set")
+
width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width))
height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height))
@@ -258,11 +300,12 @@ class OpenCVCamera(Camera):
"""
found_cameras_info = []
+ targets_to_scan: list[str | int]
if platform.system() == "Linux":
possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name)
targets_to_scan = [str(p) for p in possible_paths]
else:
- targets_to_scan = list(range(MAX_OPENCV_INDEX))
+ targets_to_scan = [int(i) for i in range(MAX_OPENCV_INDEX)]
for target in targets_to_scan:
camera = cv2.VideoCapture(target)
@@ -271,6 +314,12 @@ class OpenCVCamera(Camera):
default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
default_fps = camera.get(cv2.CAP_PROP_FPS)
default_format = camera.get(cv2.CAP_PROP_FORMAT)
+
+ # Get FOURCC code and convert to string
+ default_fourcc_code = camera.get(cv2.CAP_PROP_FOURCC)
+ default_fourcc_code_int = int(default_fourcc_code)
+ default_fourcc = "".join([chr((default_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
+
camera_info = {
"name": f"OpenCV Camera @ {target}",
"type": "OpenCV",
@@ -278,6 +327,7 @@ class OpenCVCamera(Camera):
"backend_api": camera.getBackendName(),
"default_stream_profile": {
"format": default_format,
+ "fourcc": default_fourcc,
"width": default_width,
"height": default_height,
"fps": default_fps,
@@ -289,7 +339,7 @@ class OpenCVCamera(Camera):
return found_cameras_info
- def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
+ def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -317,6 +367,9 @@ class OpenCVCamera(Camera):
start_time = time.perf_counter()
+ if self.videocapture is None:
+ raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
+
ret, frame = self.videocapture.read()
if not ret or frame is None:
@@ -329,7 +382,7 @@ class OpenCVCamera(Camera):
return processed_frame
- def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
+ def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw frame.
@@ -372,7 +425,7 @@ class OpenCVCamera(Camera):
return processed_image
- def _read_loop(self):
+ def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
@@ -383,6 +436,9 @@ class OpenCVCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
+ if self.stop_event is None:
+ raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
+
while not self.stop_event.is_set():
try:
color_image = self.read()
@@ -419,7 +475,7 @@ class OpenCVCamera(Camera):
self.thread = None
self.stop_event = None
- def async_read(self, timeout_ms: float = 200) -> np.ndarray:
+ def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -462,7 +518,7 @@ class OpenCVCamera(Camera):
return frame
- def disconnect(self):
+ def disconnect(self) -> None:
"""
Disconnects from the camera and cleans up resources.
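Note: a stand-alone check of the FOURCC round trip used in `_validate_fourcc` and `find_cameras` above. The packing helper is a pure-Python stand-in for `cv2.VideoWriter_fourcc`; the unpacking loop mirrors the one in the diff.

def fourcc_to_int(code: str) -> int:
    # Pack a 4-character code into the little-endian integer OpenCV stores.
    return sum(ord(ch) << (8 * i) for i, ch in enumerate(code))

def int_to_fourcc(value: int) -> str:
    # Unpack the integer back into its 4-character string form.
    return "".join(chr((value >> (8 * i)) & 0xFF) for i in range(4))

assert int_to_fourcc(fourcc_to_int("MJPG")) == "MJPG"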
diff --git a/src/lerobot/cameras/opencv/configuration_opencv.py b/src/lerobot/cameras/opencv/configuration_opencv.py
index 3ac92de36..37a42861c 100644
--- a/src/lerobot/cameras/opencv/configuration_opencv.py
+++ b/src/lerobot/cameras/opencv/configuration_opencv.py
@@ -17,6 +17,8 @@ from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Rotation
+__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
+
@CameraConfig.register_subclass("opencv")
@dataclass
@@ -33,8 +35,9 @@ class OpenCVCameraConfig(CameraConfig):
OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS
OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS
- # Advanced configurations
- OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
+ # Advanced configurations with FOURCC format
+ OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90, fourcc="MJPG") # With 90° rotation and MJPG format
+ OpenCVCameraConfig(0, 30, 1280, 720, fourcc="YUYV") # With YUYV format
```
Attributes:
@@ -46,17 +49,21 @@ class OpenCVCameraConfig(CameraConfig):
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
+ fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
+    - FOURCC codes must be 4-character strings (e.g., "MJPG", "YUYV"). Some common FOURCC codes: https://learn.microsoft.com/en-us/windows/win32/medfound/video-fourccs#fourcc-constants
+ - Setting FOURCC can help achieve higher frame rates on some cameras.
"""
index_or_path: int | Path
color_mode: ColorMode = ColorMode.RGB
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
+ fourcc: str | None = None
- def __post_init__(self):
+ def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
@@ -71,3 +78,8 @@ class OpenCVCameraConfig(CameraConfig):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
+
+ if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
+ raise ValueError(
+ f"`fourcc` must be a 4-character string (e.g., 'MJPG', 'YUYV'), but '{self.fourcc}' is provided."
+ )
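Note: a hedged usage sketch of the new `fourcc` option, assuming lerobot is installed. A valid 4-character code is accepted; anything else should trigger the ValueError added in `__post_init__`.

from lerobot.cameras.opencv import OpenCVCameraConfig

cfg = OpenCVCameraConfig(index_or_path=0, fps=30, width=1280, height=720, fourcc="MJPG")
try:
    OpenCVCameraConfig(index_or_path=0, fourcc="MJPEG")  # 5 characters -> rejected
except ValueError as err:
    print(err)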
diff --git a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py
index 5b2303ff2..f26cf2ad1 100644
--- a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py
+++ b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py
@@ -16,6 +16,8 @@ from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
+__all__ = ["CameraConfig", "ColorMode", "Reachy2CameraConfig"]
+
@CameraConfig.register_subclass("reachy2_camera")
@dataclass
@@ -62,7 +64,7 @@ class Reachy2CameraConfig(CameraConfig):
port: int = 50065
# use_depth: bool = False
- def __post_init__(self):
+ def __post_init__(self) -> None:
if self.name not in ["teleop", "depth"]:
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py
index c96789f96..30e096767 100644
--- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py
+++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py
@@ -23,13 +23,17 @@ import time
from threading import Event, Lock, Thread
from typing import Any
+from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
+
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
-import cv2
-import numpy as np
-from reachy2_sdk.media.camera import CameraView
-from reachy2_sdk.media.camera_manager import CameraManager
+import cv2 # type: ignore # TODO: add type stubs for OpenCV
+import numpy as np # type: ignore # TODO: add type stubs for numpy
+from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
+from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
+ CameraManager,
+)
from lerobot.utils.errors import DeviceNotConnectedError
@@ -73,7 +77,7 @@ class Reachy2Camera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
- self.latest_frame: np.ndarray | None = None
+ self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
@@ -83,13 +87,17 @@ class Reachy2Camera(Camera):
def is_connected(self) -> bool:
"""Checks if the camera is currently connected and opened."""
if self.config.name == "teleop":
- return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
+ return bool(
+ self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
+ )
elif self.config.name == "depth":
- return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
+ return bool(
+ self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
+ )
else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
- def connect(self, warmup: bool = True):
+ def connect(self, warmup: bool = True) -> None:
"""
Connects to the Reachy2 CameraManager as specified in the configuration.
"""
@@ -131,7 +139,7 @@ class Reachy2Camera(Camera):
camera_manager.disconnect()
return initialized_cameras
- def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
+ def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -152,7 +160,7 @@ class Reachy2Camera(Camera):
start_time = time.perf_counter()
- frame = None
+ frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -179,7 +187,7 @@ class Reachy2Camera(Camera):
return frame
- def _read_loop(self):
+ def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
@@ -190,6 +198,9 @@ class Reachy2Camera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
+ if self.stop_event is None:
+ raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
+
while not self.stop_event.is_set():
try:
color_image = self.read()
@@ -226,7 +237,7 @@ class Reachy2Camera(Camera):
self.thread = None
self.stop_event = None
- def async_read(self, timeout_ms: float = 200) -> np.ndarray:
+ def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -269,7 +280,7 @@ class Reachy2Camera(Camera):
return frame
- def disconnect(self):
+ def disconnect(self) -> None:
"""
Stops the background read thread (if running).
diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py
index cc816e552..f2906fdd8 100644
--- a/src/lerobot/cameras/realsense/camera_realsense.py
+++ b/src/lerobot/cameras/realsense/camera_realsense.py
@@ -21,11 +21,12 @@ import time
from threading import Event, Lock, Thread
from typing import Any
-import cv2
-import numpy as np
+import cv2 # type: ignore # TODO: add type stubs for OpenCV
+import numpy as np # type: ignore # TODO: add type stubs for numpy
+from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
try:
- import pyrealsense2 as rs
+ import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
except Exception as e:
logging.info(f"Could not import realsense: {e}")
@@ -132,7 +133,7 @@ class RealSenseCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
- self.latest_frame: np.ndarray | None = None
+ self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -150,7 +151,7 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
- def connect(self, warmup: bool = True):
+ def connect(self, warmup: bool = True) -> None:
"""
Connects to the RealSense camera specified in the configuration.
@@ -264,7 +265,7 @@ class RealSenseCamera(Camera):
serial_number = str(found_devices[0]["serial_number"])
return serial_number
- def _configure_rs_pipeline_config(self, rs_config):
+ def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
"""Creates and configures the RealSense pipeline configuration object."""
rs.config.enable_device(rs_config, self.serial_number)
@@ -293,6 +294,9 @@ class RealSenseCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
+ if self.rs_profile is None:
+ raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
+
stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
if self.fps is None:
@@ -308,7 +312,7 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
- def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
+ def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -336,6 +340,9 @@ class RealSenseCamera(Camera):
start_time = time.perf_counter()
+ if self.rs_pipeline is None:
+ raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
+
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
@@ -351,7 +358,7 @@ class RealSenseCamera(Camera):
return depth_map_processed
- def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
+ def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
@@ -376,6 +383,9 @@ class RealSenseCamera(Camera):
start_time = time.perf_counter()
+ if self.rs_pipeline is None:
+ raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
+
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
@@ -392,8 +402,8 @@ class RealSenseCamera(Camera):
return color_image_processed
def _postprocess_image(
- self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
- ) -> np.ndarray:
+ self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
+ ) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw color frame.
@@ -438,7 +448,7 @@ class RealSenseCamera(Camera):
return processed_image
- def _read_loop(self):
+ def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
@@ -449,6 +459,9 @@ class RealSenseCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
+ if self.stop_event is None:
+ raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
+
while not self.stop_event.is_set():
try:
color_image = self.read(timeout_ms=500)
@@ -474,7 +487,7 @@ class RealSenseCamera(Camera):
self.thread.daemon = True
self.thread.start()
- def _stop_read_thread(self):
+ def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
@@ -486,7 +499,7 @@ class RealSenseCamera(Camera):
self.stop_event = None
# NOTE(Steven): Missing implementation for depth for now
- def async_read(self, timeout_ms: float = 200) -> np.ndarray:
+ def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame data (color) asynchronously.
@@ -529,7 +542,7 @@ class RealSenseCamera(Camera):
return frame
- def disconnect(self):
+ def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.
diff --git a/src/lerobot/cameras/realsense/configuration_realsense.py b/src/lerobot/cameras/realsense/configuration_realsense.py
index 36a86876d..a094128bc 100644
--- a/src/lerobot/cameras/realsense/configuration_realsense.py
+++ b/src/lerobot/cameras/realsense/configuration_realsense.py
@@ -59,7 +59,7 @@ class RealSenseCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
- def __post_init__(self):
+ def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py
index aa6ff98b4..1b2d386d6 100644
--- a/src/lerobot/cameras/utils.py
+++ b/src/lerobot/cameras/utils.py
@@ -53,14 +53,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
- import cv2
+ import cv2 # type: ignore # TODO: add type stubs for OpenCV
if rotation == Cv2Rotation.ROTATE_90:
- return cv2.ROTATE_90_CLOCKWISE
+ return int(cv2.ROTATE_90_CLOCKWISE)
elif rotation == Cv2Rotation.ROTATE_180:
- return cv2.ROTATE_180
+ return int(cv2.ROTATE_180)
elif rotation == Cv2Rotation.ROTATE_270:
- return cv2.ROTATE_90_COUNTERCLOCKWISE
+ return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
return None
@@ -69,8 +69,8 @@ def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
- return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
+ return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
- return cv2.CAP_ANY
+ return int(cv2.CAP_ANY)
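Note: a quick sanity sketch for the `int()` wrapping above, assuming lerobot and opencv-python are installed: the helper now always returns a plain int (or None), so it composes cleanly with typed call sites.

import cv2

from lerobot.cameras.configs import Cv2Rotation
from lerobot.cameras.utils import get_cv2_rotation

assert get_cv2_rotation(Cv2Rotation.NO_ROTATION) is None
assert get_cv2_rotation(Cv2Rotation.ROTATE_90) == int(cv2.ROTATE_90_CLOCKWISE)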
diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py
index c72438ca7..f613b5251 100644
--- a/src/lerobot/configs/default.py
+++ b/src/lerobot/configs/default.py
@@ -57,7 +57,7 @@ class EvalConfig:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: bool = False
- def __post_init__(self):
+ def __post_init__(self) -> None:
if self.batch_size > self.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py
index cfe48cf87..2f085da56 100644
--- a/src/lerobot/configs/eval.py
+++ b/src/lerobot/configs/eval.py
@@ -13,8 +13,8 @@
# limitations under the License.
import datetime as dt
-import logging
from dataclasses import dataclass, field
+from logging import getLogger
from pathlib import Path
from lerobot import envs, policies # noqa: F401
@@ -22,6 +22,8 @@ from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
+logger = getLogger(__name__)
+
@dataclass
class EvalPipelineConfig:
@@ -34,25 +36,31 @@ class EvalPipelineConfig:
output_dir: Path | None = None
job_name: str | None = None
seed: int | None = 1000
+    # Rename map applied to observations to override the image and state keys
+ rename_map: dict[str, str] = field(default_factory=dict)
- def __post_init__(self):
+ def __post_init__(self) -> None:
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
- self.policy.pretrained_path = policy_path
+ self.policy.pretrained_path = Path(policy_path)
else:
- logging.warning(
+ logger.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
)
if not self.job_name:
if self.env is None:
- self.job_name = f"{self.policy.type}"
+ self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
else:
- self.job_name = f"{self.env.type}_{self.policy.type}"
+ self.job_name = (
+ f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
+ )
+
+ logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
if not self.output_dir:
now = dt.datetime.now()
diff --git a/src/lerobot/configs/parser.py b/src/lerobot/configs/parser.py
index 2296eaa20..57ebaf8fa 100644
--- a/src/lerobot/configs/parser.py
+++ b/src/lerobot/configs/parser.py
@@ -16,14 +16,19 @@ import inspect
import pkgutil
import sys
from argparse import ArgumentError
-from collections.abc import Sequence
+from collections.abc import Callable, Iterable, Sequence
from functools import wraps
from pathlib import Path
+from pkgutil import ModuleInfo
+from types import ModuleType
+from typing import Any, TypeVar, cast
import draccus
from lerobot.utils.utils import has_method
+F = TypeVar("F", bound=Callable[..., object])
+
PATH_KEY = "path"
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
@@ -60,7 +65,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
return None
-def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
+def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
"""Parse plugin-related arguments from command-line arguments.
This function extracts arguments from command-line arguments that match a specified suffix pattern.
@@ -127,7 +132,7 @@ def load_plugin(plugin_path: str) -> None:
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
) from e
- def iter_namespace(ns_pkg):
+ def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
try:
@@ -148,6 +153,8 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
+ if args is None:
+ return []
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
@@ -171,7 +178,8 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
if isinstance(fields_to_filter, str):
fields_to_filter = [fields_to_filter]
- filtered_args = args
+ filtered_args = [] if args is None else list(args)
+
for field in fields_to_filter:
if get_path_arg(field, args):
if get_type_arg(field, args):
@@ -184,7 +192,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
return filtered_args
-def wrap(config_path: Path | None = None):
+def wrap(config_path: Path | None = None) -> Callable[[F], F]:
"""
HACK: Similar to draccus.wrap but does three additional things:
- Will remove '.path' arguments from CLI in order to process them later on.
@@ -195,9 +203,9 @@ def wrap(config_path: Path | None = None):
from the CLI '.type' arguments
"""
- def wrapper_outer(fn):
+ def wrapper_outer(fn: F) -> F:
@wraps(fn)
- def wrapper_inner(*args, **kwargs):
+ def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
argspec = inspect.getfullargspec(fn)
argtype = argspec.annotations[argspec.args[0]]
if len(args) > 0 and type(args[0]) is argtype:
@@ -225,6 +233,6 @@ def wrap(config_path: Path | None = None):
response = fn(cfg, *args, **kwargs)
return response
- return wrapper_inner
+ return cast(F, wrapper_inner)
- return wrapper_outer
+ return cast(Callable[[F], F], wrapper_outer)
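Note: a self-contained illustration of the `TypeVar("F", bound=Callable[..., object])` plus `cast()` pattern adopted above. The decorator preserves the wrapped function's signature for static checkers while the runtime wrapper stays loosely typed internally; `log_calls` and `add` are hypothetical.

from collections.abc import Callable
from functools import wraps
from typing import Any, TypeVar, cast

F = TypeVar("F", bound=Callable[..., object])

def log_calls(fn: F) -> F:
    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)
    return cast(F, wrapper)

@log_calls
def add(a: int, b: int) -> int:
    return a + b

assert add(1, 2) == 3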
diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py
index 76810ac72..df1559ac2 100644
--- a/src/lerobot/configs/policies.py
+++ b/src/lerobot/configs/policies.py
@@ -14,12 +14,12 @@
import abc
import builtins
import json
-import logging
import os
import tempfile
from dataclasses import dataclass, field
+from logging import getLogger
from pathlib import Path
-from typing import TypeVar
+from typing import Any, TypeVar
import draccus
from huggingface_hub import hf_hub_download
@@ -34,10 +34,11 @@ from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
T = TypeVar("T", bound="PreTrainedConfig")
+logger = getLogger(__name__)
@dataclass
-class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
+class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
"""
Base configuration class for policy models.
@@ -57,7 +58,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
- device: str | None = None # cuda | cpu | mp
+ device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
@@ -65,7 +66,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
# Whether the policy employed PEFT for training.
use_peft: bool = False
- push_to_hub: bool = True
+ push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
repo_id: str | None = None
# Upload on private repository on the Hugging Face hub.
@@ -76,38 +77,41 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
license: str | None = None
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
- pretrained_path: str | None = None
+ pretrained_path: Path | None = None
- def __post_init__(self):
+ def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
- logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+ logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
- logging.warning(
+ logger.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@property
def type(self) -> str:
- return self.get_choice_name(self.__class__)
+ choice_name = self.get_choice_name(self.__class__)
+ if not isinstance(choice_name, str):
+ raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
+ return choice_name
@property
@abc.abstractmethod
- def observation_delta_indices(self) -> list | None:
+ def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@property
@abc.abstractmethod
- def action_delta_indices(self) -> list | None:
+ def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@property
@abc.abstractmethod
- def reward_delta_indices(self) -> list | None:
+ def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@abc.abstractmethod
@@ -157,13 +161,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
- resume_download: bool = None,
- proxies: dict | None = None,
+ resume_download: bool | None = None,
+ proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
- **policy_kwargs,
+ **policy_kwargs: Any,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
@@ -171,7 +175,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
- print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
+ logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
@@ -197,6 +201,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
with draccus.config_type("json"):
orig_config = draccus.parse(cls, config_file, args=[])
+ if config_file is None:
+ raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
+
with open(config_file) as f:
config = json.load(f)
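Note: a hedged sketch of the local-path branch of `from_pretrained` shown above: if the model id is a directory containing the config file, use it directly, otherwise fall back to the Hub download (not shown). The CONFIG_NAME value here is an assumption; the real constant is imported in the module.

import os

CONFIG_NAME = "config.json"  # assumed value for illustration only

def resolve_local_config(model_id: str) -> str | None:
    # Return the path to a local config file, or None to trigger the Hub download path.
    if os.path.isdir(model_id) and CONFIG_NAME in os.listdir(model_id):
        return os.path.join(model_id, CONFIG_NAME)
    return None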
diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py
index 546c4dbe0..c64b931be 100644
--- a/src/lerobot/configs/train.py
+++ b/src/lerobot/configs/train.py
@@ -16,6 +16,7 @@ import datetime as dt
import os
from dataclasses import dataclass, field
from pathlib import Path
+from typing import Any
import draccus
from huggingface_hub import hf_hub_download
@@ -64,18 +65,18 @@ class TrainPipelineConfig(HubMixin):
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
peft: PeftConfig | None = None
+ checkpoint_path: Path | None = field(init=False, default=None)
+    # Rename map applied to observations to override the image and state keys
+ rename_map: dict[str, str] = field(default_factory=dict)
- def __post_init__(self):
- self.checkpoint_path = None
-
- def validate(self):
+ def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
- self.policy.pretrained_path = policy_path
+ self.policy.pretrained_path = Path(policy_path)
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
@@ -83,14 +84,22 @@ class TrainPipelineConfig(HubMixin):
raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
+
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now."
)
- policy_path = Path(config_path).parent
- self.policy.pretrained_path = policy_path
- self.checkpoint_path = policy_path.parent
+
+ policy_dir = Path(config_path).parent
+ if self.policy is not None:
+ self.policy.pretrained_path = policy_dir
+ self.checkpoint_path = policy_dir.parent
+
+ if self.policy is None:
+ raise ValueError(
+ "Policy is not configured. Please specify a pretrained policy with `--policy.path`."
+ )
if not self.job_name:
if self.env is None:
@@ -127,8 +136,8 @@ class TrainPipelineConfig(HubMixin):
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
- def to_dict(self) -> dict:
- return draccus.encode(self)
+ def to_dict(self) -> dict[str, Any]:
+        return draccus.encode(self)  # type: ignore[no-any-return] # because the third-party library draccus uses Any as the return type
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
@@ -140,13 +149,13 @@ class TrainPipelineConfig(HubMixin):
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
- resume_download: bool = None,
- proxies: dict | None = None,
+ resume_download: bool | None = None,
+ proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
- **kwargs,
+ **kwargs: Any,
) -> "TrainPipelineConfig":
model_id = str(pretrained_name_or_path)
config_file: str | None = None
@@ -182,4 +191,6 @@ class TrainPipelineConfig(HubMixin):
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
- dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
+ # NOTE: In RL, we don't need an offline dataset
+ # TODO: Make `TrainPipelineConfig.dataset` optional
+    dataset: DatasetConfig | None = None  # type: ignore[assignment] # because the parent class has made its type non-optional
diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py
index cb578060e..18359ef05 100644
--- a/src/lerobot/configs/types.py
+++ b/src/lerobot/configs/types.py
@@ -42,4 +42,11 @@ class NormalizationMode(str, Enum):
@dataclass
class PolicyFeature:
type: FeatureType
- shape: tuple
+ shape: tuple[int, ...]
+
+
+class RTCAttentionSchedule(str, Enum):
+ ZEROS = "ZEROS"
+ ONES = "ONES"
+ LINEAR = "LINEAR"
+ EXP = "EXP"
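Note: a quick sketch of the tightened annotations above, assuming lerobot is installed and that `FeatureType.VISUAL` is one of the existing enum members: `shape` is now an int tuple, and the new RTCAttentionSchedule enum round-trips from its string value.

from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule

feat = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 480, 640))
assert RTCAttentionSchedule("LINEAR") is RTCAttentionSchedule.LINEAR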
diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py
index 2735ba0a0..2fb68dca1 100644
--- a/src/lerobot/datasets/dataset_tools.py
+++ b/src/lerobot/datasets/dataset_tools.py
@@ -39,6 +39,7 @@ from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
+ DATA_DIR,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -962,28 +963,23 @@ def _copy_data_with_feature_changes(
remove_features: list[str] | None = None,
) -> None:
"""Copy data while adding or removing features."""
- if dataset.meta.episodes is None:
- dataset.meta.episodes = load_episodes(dataset.meta.root)
+ data_dir = dataset.root / DATA_DIR
+ parquet_files = sorted(data_dir.glob("*/*.parquet"))
- # Map file paths to episode indices to extract chunk/file indices
- file_to_episodes: dict[Path, set[int]] = {}
- for ep_idx in range(dataset.meta.total_episodes):
- file_path = dataset.meta.get_data_file_path(ep_idx)
- if file_path not in file_to_episodes:
- file_to_episodes[file_path] = set()
- file_to_episodes[file_path].add(ep_idx)
+ if not parquet_files:
+ raise ValueError(f"No parquet files found in {data_dir}")
frame_idx = 0
- for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
- df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
+ for src_path in tqdm(parquet_files, desc="Processing data files"):
+ df = pd.read_parquet(src_path).reset_index(drop=True)
- # Get chunk_idx and file_idx from the source file's first episode
- episodes_in_file = file_to_episodes[src_path]
- first_ep_idx = min(episodes_in_file)
- src_ep = dataset.meta.episodes[first_ep_idx]
- chunk_idx = src_ep["data/chunk_index"]
- file_idx = src_ep["data/file_index"]
+ relative_path = src_path.relative_to(dataset.root)
+ chunk_dir = relative_path.parts[1]
+ file_name = relative_path.parts[2]
+
+ chunk_idx = int(chunk_dir.split("-")[1])
+ file_idx = int(file_name.split("-")[1].split(".")[0])
if remove_features:
df = df.drop(columns=remove_features, errors="ignore")
@@ -1009,7 +1005,7 @@ def _copy_data_with_feature_changes(
df[feature_name] = feature_slice
frame_idx = end_idx
- # Write using the preserved chunk_idx and file_idx from source
+ # Write using the same chunk/file structure as source
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
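Note: a worked example of the path parsing above, assuming the on-disk layout data/chunk-000/file-003.parquet; the directory and file naming follows the `split("-")` logic in the diff and is not independently verified.

from pathlib import Path

relative_path = Path("data/chunk-000/file-003.parquet")
chunk_dir = relative_path.parts[1]                       # "chunk-000"
file_name = relative_path.parts[2]                       # "file-003.parquet"
chunk_idx = int(chunk_dir.split("-")[1])                 # 0
file_idx = int(file_name.split("-")[1].split(".")[0])    # 3
assert (chunk_idx, file_idx) == (0, 3)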
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index ae142c1e8..9c94235c9 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -430,9 +430,7 @@ class LeRobotDatasetMetadata:
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
- video_path = self.root / self.video_path.format(
- video_key=video_key, chunk_index=0, file_index=0
- )
+ video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_chunk_settings(
@@ -686,6 +684,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer = None
self.writer = None
self.latest_episode = None
+ self._current_file_start_frame = None # Track the starting frame index of the current parquet file
self.root.mkdir(exist_ok=True, parents=True)
@@ -708,10 +707,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (AssertionError, FileNotFoundError, NotADirectoryError):
- self.revision = get_safe_version(self.repo_id, self.revision)
+ if is_valid_version(self.revision):
+ self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
+        # When only a subset of episodes is loaded, build a mapping:
+        # absolute_index -> relative_index_in_filtered_dataset
+ self._absolute_to_relative_idx = None
+ if self.episodes is not None:
+ self._absolute_to_relative_idx = {
+ abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
+ for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
+ }
+
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -830,31 +839,40 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self.features)
- hf_dataset = load_nested_dataset(self.root / "data", features=features)
+ hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def _check_cached_episodes_sufficient(self) -> bool:
- """Check if the cached dataset contains all requested episodes."""
+ """Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False
# Get available episode indices from cached dataset
available_episodes = {
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
- for ep_idx in self.hf_dataset["episode_index"]
+ for ep_idx in self.hf_dataset.unique("episode_index")
}
# Determine requested episodes
if self.episodes is None:
- # Requesting all episodes - check if we have all episodes from metadata
requested_episodes = set(range(self.meta.total_episodes))
else:
- # Requesting specific episodes
requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data
- return requested_episodes.issubset(available_episodes)
+ if not requested_episodes.issubset(available_episodes):
+ return False
+
+ # Check if all required video files exist
+ if len(self.meta.video_keys) > 0:
+ for ep_idx in requested_episodes:
+ for vid_key in self.meta.video_keys:
+ video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
+ if not video_path.exists():
+ return False
+
+ return True
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
@@ -921,7 +939,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_timestamps = {}
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
- timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
+ if self._absolute_to_relative_idx is not None:
+ relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
+ timestamps = self.hf_dataset[relative_indices]["timestamp"]
+ else:
+ timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
@@ -929,11 +951,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
- return {
- key: torch.stack(self.hf_dataset[q_idx][key])
- for key, q_idx in query_indices.items()
- if key not in self.meta.video_keys
- }
+ """
+ Query dataset for indices across keys, skipping video keys.
+
+ Tries column-first [key][indices] for speed, falls back to row-first.
+
+ Args:
+ query_indices: Dict mapping keys to index lists to retrieve
+
+ Returns:
+ Dict with stacked tensors of queried data (video keys excluded)
+ """
+ result: dict = {}
+ for key, q_idx in query_indices.items():
+ if key in self.meta.video_keys:
+ continue
+ # Map absolute indices to relative indices if needed
+ relative_indices = (
+ q_idx
+ if self._absolute_to_relative_idx is None
+ else [self._absolute_to_relative_idx[idx] for idx in q_idx]
+ )
+ try:
+ result[key] = torch.stack(self.hf_dataset[key][relative_indices])
+ except (KeyError, TypeError, IndexError):
+ result[key] = torch.stack(self.hf_dataset[relative_indices][key])
+ return result
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -1231,6 +1274,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
global_frame_index = 0
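+        # Track the global frame index at which the current parquet data file starts; used later to
+        # estimate the average file size per frame when deciding whether to roll over to a new file.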
+ self._current_file_start_frame = 0
# However, if the episodes already exists
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
@@ -1242,6 +1286,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# When resuming, move to the next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
+ self._current_file_start_frame = global_frame_index
else:
# Retrieve information from the latest parquet file
latest_ep = self.latest_episode
@@ -1252,7 +1297,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_file_size_in_mb(latest_path)
- frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
+ frames_in_current_file = global_frame_index - self._current_file_start_frame
av_size_per_frame = (
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
)
@@ -1266,6 +1311,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
self._close_writer()
self._writer_closed_for_reading = False
+ self._current_file_start_frame = global_frame_index
ep_dict["data/chunk_index"] = chunk_idx
ep_dict["data/file_index"] = file_idx
@@ -1469,9 +1515,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
+ obj._absolute_to_relative_idx = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.writer = None
obj.latest_episode = None
+ obj._current_file_start_frame = None
# Initialize tracking for incremental recording
obj._lazy_loading = False
obj._recorded_frames = 0
diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py
index f7072c72f..beacc48d9 100644
--- a/src/lerobot/datasets/transforms.py
+++ b/src/lerobot/datasets/transforms.py
@@ -206,6 +206,11 @@ class ImageTransformsConfig:
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
+ "affine": ImageTransformConfig(
+ weight=1.0,
+ type="RandomAffine",
+ kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
+ ),
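+        # RandomAffine samples a rotation within ±5 degrees and a translation of up to 5% of the image size.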
}
)
@@ -217,6 +222,8 @@ def make_transform_from_config(cfg: ImageTransformConfig):
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
+ elif cfg.type == "RandomAffine":
+ return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py
index 37d8432b2..ce4cf1da1 100644
--- a/src/lerobot/datasets/utils.py
+++ b/src/lerobot/datasets/utils.py
@@ -28,6 +28,7 @@ import numpy as np
import packaging.version
import pandas
import pandas as pd
+import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from datasets import Dataset
@@ -103,7 +104,9 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
return chunk_idx, file_idx
-def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
+def load_nested_dataset(
+ pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
+) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
@@ -111,15 +114,26 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
+ episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
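+
+    Example (illustrative):
+        load_nested_dataset(Path("data"), episodes=[0, 2])  # keeps only frames with episode_index 0 or 2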
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
- # TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars():
- datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
- return datasets
+        # When no filtering is needed, Dataset.from_parquet memory-maps the parquet files for low RAM usage.
+        # Otherwise, filter via PyArrow predicate pushdown; to_table materializes only the requested episodes in memory.
+ if episodes is None:
+ return Dataset.from_parquet([str(path) for path in paths], features=features)
+
+ arrow_dataset = pa_ds.dataset(paths, format="parquet")
+ filter_expr = pa_ds.field("episode_index").isin(episodes)
+ table = arrow_dataset.to_table(filter=filter_expr)
+
+ if features is not None:
+ table = table.cast(features.arrow_schema)
+
+ return Dataset(table)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py
index b8ae29ad6..74be6bfa4 100644
--- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py
+++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py
@@ -98,7 +98,7 @@ OLD
videos/chunk-000/CAMERA/episode_000000.mp4
NEW
-videos/chunk-000/file_000.mp4
+videos/CAMERA/chunk-000/file_000.mp4
-------------------------
OLD
episodes.jsonl
diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py
index 740cdb602..0de791919 100644
--- a/src/lerobot/datasets/video_utils.py
+++ b/src/lerobot/datasets/video_utils.py
@@ -342,8 +342,8 @@ def encode_video_frames(
# Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.")
- dummy_image = Image.open(input_list[0])
- width, height = dummy_image.size
+ with Image.open(input_list[0]) as dummy_image:
+ width, height = dummy_image.size
# Define video codec options
video_options = {}
@@ -373,11 +373,12 @@ def encode_video_frames(
# Loop through input frames and encode them
for input_data in input_list:
- input_image = Image.open(input_data).convert("RGB")
- input_frame = av.VideoFrame.from_image(input_image)
- packet = output_stream.encode(input_frame)
- if packet:
- output.mux(packet)
+ with Image.open(input_data) as input_image:
+ input_image = input_image.convert("RGB")
+ input_frame = av.VideoFrame.from_image(input_image)
+ packet = output_stream.encode(input_frame)
+ if packet:
+ output.mux(packet)
# Flush the encoder
packet = output_stream.encode()
diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py
index 3aa155093..e696f683e 100644
--- a/src/lerobot/envs/configs.py
+++ b/src/lerobot/envs/configs.py
@@ -21,7 +21,22 @@ import draccus
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
-from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import (
+ ACTION,
+ LIBERO_KEY_EEF_MAT,
+ LIBERO_KEY_EEF_POS,
+ LIBERO_KEY_EEF_QUAT,
+ LIBERO_KEY_GRIPPER_QPOS,
+ LIBERO_KEY_GRIPPER_QVEL,
+ LIBERO_KEY_JOINTS_POS,
+ LIBERO_KEY_JOINTS_VEL,
+ LIBERO_KEY_PIXELS_AGENTVIEW,
+ LIBERO_KEY_PIXELS_EYE_IN_HAND,
+ OBS_ENV_STATE,
+ OBS_IMAGE,
+ OBS_IMAGES,
+ OBS_STATE,
+)
@dataclass
@@ -37,6 +52,16 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def type(self) -> str:
return self.get_choice_name(self.__class__)
+ @property
+ def package_name(self) -> str:
+ """Package name to import if environment not found in gym registry"""
+ return f"gym_{self.type}"
+
+ @property
+ def gym_id(self) -> str:
+ """ID string used in gym.make() to instantiate the environment"""
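+        # e.g. for the "aloha" env with task "AlohaInsertion-v0" this is "gym_aloha/AlohaInsertion-v0"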
+ return f"{self.package_name}/{self.task}"
+
@property
@abc.abstractmethod
def gym_kwargs(self) -> dict:
@@ -236,28 +261,61 @@ class LiberoEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
- "agent_pos": OBS_STATE,
- "pixels/agentview_image": f"{OBS_IMAGES}.image",
- "pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
+ LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
+ LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
+ LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
+ LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
+ LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
+ LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
+ LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
+ LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
+ LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
- self.features["pixels/agentview_image"] = PolicyFeature(
+ self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
- self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
+ self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
elif self.obs_type == "pixels_agent_pos":
- self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
- self.features["pixels/agentview_image"] = PolicyFeature(
+ self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
- self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
+ self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
+ self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(3,),
+ )
+ self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(4,),
+ )
+ self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(3, 3),
+ )
+ self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(2,),
+ )
+ self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(2,),
+ )
+ self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(7,),
+ )
+ self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(7,),
+ )
else:
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py
index 059e0e11a..bda84fd78 100644
--- a/src/lerobot/envs/factory.py
+++ b/src/lerobot/envs/factory.py
@@ -14,10 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
+from typing import Any
import gymnasium as gym
+from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
+from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
+from lerobot.processor import ProcessorStep
+from lerobot.processor.env_processor import LiberoProcessorStep
+from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -31,16 +37,60 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
-def make_env(
- cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
-) -> dict[str, dict[int, gym.vector.VectorEnv]]:
- """Makes a gym vector environment according to the config.
+def make_env_pre_post_processors(
+ env_cfg: EnvConfig,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+]:
+ """
+ Create preprocessor and postprocessor pipelines for environment observations.
+
+ This function creates processor pipelines that transform raw environment
+ observations and actions. By default, it returns identity processors that do nothing.
+ For specific environments like LIBERO, it adds environment-specific processing steps.
Args:
- cfg (EnvConfig): the config of the environment to instantiate.
+ env_cfg: The configuration of the environment.
+
+ Returns:
+ A tuple containing:
+ - preprocessor: Pipeline that processes environment observations
+ - postprocessor: Pipeline that processes environment outputs (currently identity)
+ """
+ # Preprocessor and Postprocessor steps are Identity for most environments
+ preprocessor_steps: list[ProcessorStep] = []
+ postprocessor_steps: list[ProcessorStep] = []
+
+ # For LIBERO environments, add the LiberoProcessorStep to preprocessor
+ if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+ preprocessor_steps.append(LiberoProcessorStep())
+
+ preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
+ postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
+
+ return preprocessor, postprocessor
+
+
+def make_env(
+ cfg: EnvConfig | str,
+ n_envs: int = 1,
+ use_async_envs: bool = False,
+ hub_cache_dir: str | None = None,
+ trust_remote_code: bool = False,
+) -> dict[str, dict[int, gym.vector.VectorEnv]]:
+ """Makes a gym vector environment according to the config or Hub reference.
+
+ Args:
+ cfg (EnvConfig | str): Either an `EnvConfig` object describing the environment to build locally,
+ or a Hugging Face Hub repository identifier (e.g. `"username/repo"`). In the latter case,
+ the repo must include a Python file (usually `env.py`).
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
+ hub_cache_dir (str | None): Optional cache path for downloaded hub files.
+ trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
+ Default False — must be set to True to import/exec hub `env.py`.
Raises:
ValueError: if n_envs < 1
@@ -53,6 +103,21 @@ def make_env(
- For single-task environments: a single suite entry (cfg.type) with task_id=0.
"""
+ # if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py")
+ # simplified: only support hub-provided `make_env`
+ if isinstance(cfg, str):
+ # _download_hub_file will raise the same RuntimeError if trust_remote_code is False
+ repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir)
+
+ # import and surface clear import errors
+ module = _import_hub_module(local_file, repo_id)
+
+ # call the hub-provided make_env
+ raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
+
+ # normalize the return into {suite: {task_id: vec_env}}
+ return _normalize_hub_result(raw_result)
+
if n_envs < 1:
raise ValueError("`n_envs` must be at least 1")
@@ -84,17 +149,24 @@ def make_env(
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
- package_name = f"gym_{cfg.type}"
- try:
- importlib.import_module(package_name)
- except ModuleNotFoundError as e:
- print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
- raise e
- gym_handle = f"{package_name}/{cfg.task}"
+ if cfg.gym_id not in gym_registry:
+ print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
+ try:
+ importlib.import_module(cfg.package_name)
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
+ f"Please install it or check PYTHONPATH."
+ ) from e
+
+ if cfg.gym_id not in gym_registry:
+ raise gym.error.NameNotFound(
+ f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
+ )
def _make_one():
- return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
+ return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py
index 94b08e991..35bc58e07 100644
--- a/src/lerobot/envs/libero.py
+++ b/src/lerobot/envs/libero.py
@@ -28,7 +28,6 @@ import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
-from robosuite.utils.transform_utils import quat2axisangle
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
@@ -175,11 +174,36 @@ class LiberoEnv(gym.Env):
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
- "agent_pos": spaces.Box(
- low=AGENT_POS_LOW,
- high=AGENT_POS_HIGH,
- shape=(OBS_STATE_DIM,),
- dtype=np.float64,
+ "robot_state": spaces.Dict(
+ {
+ "eef": spaces.Dict(
+ {
+ "pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
+ "quat": spaces.Box(
+ low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
+ ),
+ "mat": spaces.Box(
+ low=-np.inf, high=np.inf, shape=(3, 3), dtype=np.float64
+ ),
+ }
+ ),
+ "gripper": spaces.Dict(
+ {
+ "qpos": spaces.Box(
+ low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
+ ),
+ "qvel": spaces.Box(
+ low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
+ ),
+ }
+ ),
+ "joints": spaces.Dict(
+ {
+ "pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
+ "vel": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
+ }
+ ),
+ }
),
}
)
@@ -191,6 +215,7 @@ class LiberoEnv(gym.Env):
def render(self):
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
+ image = image[::-1, ::-1] # flip both H and W for visualization
return image
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
@@ -212,23 +237,48 @@ class LiberoEnv(gym.Env):
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
- image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image
- state = np.concatenate(
- (
- raw_obs["robot0_eef_pos"],
- quat2axisangle(raw_obs["robot0_eef_quat"]),
- raw_obs["robot0_gripper_qpos"],
- )
- )
- agent_pos = state
+
+ eef_pos = raw_obs.get("robot0_eef_pos")
+ eef_quat = raw_obs.get("robot0_eef_quat")
+
+ # rotation matrix from controller
+ eef_mat = self._env.robots[0].controller.ee_ori_mat if eef_pos is not None else None
+ gripper_qpos = raw_obs.get("robot0_gripper_qpos")
+ gripper_qvel = raw_obs.get("robot0_gripper_qvel")
+ joint_pos = raw_obs.get("robot0_joint_pos")
+ joint_vel = raw_obs.get("robot0_joint_vel")
+ obs = {
+ "pixels": images,
+ "robot_state": {
+ "eef": {
+ "pos": eef_pos, # (3,)
+ "quat": eef_quat, # (4,)
+ "mat": eef_mat, # (3, 3)
+ },
+ "gripper": {
+ "qpos": gripper_qpos, # (2,)
+ "qvel": gripper_qvel, # (2,)
+ },
+ "joints": {
+ "pos": joint_pos, # (7,)
+ "vel": joint_vel, # (7,)
+ },
+ },
+ }
if self.obs_type == "pixels":
return {"pixels": images.copy()}
+
if self.obs_type == "pixels_agent_pos":
- return {
- "pixels": images.copy(),
- "agent_pos": agent_pos,
- }
+ # Validate required fields are present
+ if eef_pos is None or eef_quat is None or gripper_qpos is None:
+ raise ValueError(
+ f"Missing required robot state fields in raw observation. "
+ f"Got eef_pos={eef_pos is not None}, eef_quat={eef_quat is not None}, "
+ f"gripper_qpos={gripper_qpos is not None}"
+ )
+ return obs
+
raise NotImplementedError(
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
@@ -355,12 +405,10 @@ def create_libero_envs(
print(f"Restricting to task_ids={task_ids_filter}")
out: dict[str, dict[int, Any]] = defaultdict(dict)
-
for suite_name in suite_names:
suite = _get_suite(suite_name)
total = len(suite.tasks)
selected = _select_task_ids(total, task_ids_filter)
-
if not selected:
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py
index 5584e0bff..8d0f24922 100644
--- a/src/lerobot/envs/utils.py
+++ b/src/lerobot/envs/utils.py
@@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib.util
+import os
import warnings
from collections.abc import Mapping, Sequence
from functools import singledispatch
@@ -22,14 +24,27 @@ import einops
import gymnasium as gym
import numpy as np
import torch
+from huggingface_hub import hf_hub_download, snapshot_download
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
-from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import get_channel_first_image_shape
+def _convert_nested_dict(d):
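+    """Recursively convert numpy arrays in a (possibly nested) dict to torch tensors, leaving other values as-is."""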
+ result = {}
+ for k, v in d.items():
+ if isinstance(v, dict):
+ result[k] = _convert_nested_dict(v)
+ elif isinstance(v, np.ndarray):
+ result[k] = torch.from_numpy(v)
+ else:
+ result[k] = v
+ return result
+
+
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
@@ -75,12 +90,14 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[OBS_ENV_STATE] = env_state
- # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
- agent_pos = torch.from_numpy(observations["agent_pos"]).float()
- if agent_pos.dim() == 1:
- agent_pos = agent_pos.unsqueeze(0)
- return_observations[OBS_STATE] = agent_pos
+ if "agent_pos" in observations:
+ agent_pos = torch.from_numpy(observations["agent_pos"]).float()
+ if agent_pos.dim() == 1:
+ agent_pos = agent_pos.unsqueeze(0)
+ return_observations[OBS_STATE] = agent_pos
+ if "robot_state" in observations:
+ return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
return return_observations
@@ -195,3 +212,132 @@ def _(envs: Sequence) -> None:
@close_envs.register
def _(env: gym.Env) -> None:
_close_single_env(env)
+
+
+# helper to safely load a python file as a module
+def _load_module_from_path(path: str, module_name: str | None = None):
+ module_name = module_name or f"hub_env_{os.path.basename(path).replace('.', '_')}"
+ spec = importlib.util.spec_from_file_location(module_name, path)
+ if spec is None:
+ raise ImportError(f"Could not load module spec for {module_name} from {path}")
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module) # type: ignore
+ return module
+
+
+# helper to parse hub string (supports "user/repo", "user/repo@rev", optional path)
+# examples:
+# "user/repo" -> will look for env.py at repo root
+# "user/repo@main:envs/my_env.py" -> explicit revision and path
+def _parse_hub_url(hub_uri: str):
+ # very small parser: [repo_id][@revision][:path]
+ # repo_id is required (user/repo or org/repo)
+ revision = None
+ file_path = "env.py"
+ if "@" in hub_uri:
+ repo_and_rev, *rest = hub_uri.split(":", 1)
+ repo_id, rev = repo_and_rev.split("@", 1)
+ revision = rev
+ if rest:
+ file_path = rest[0]
+ else:
+ repo_id, *rest = hub_uri.split(":", 1)
+ if rest:
+ file_path = rest[0]
+ return repo_id, revision, file_path
+
+
+def _download_hub_file(
+ cfg_str: str,
+ trust_remote_code: bool,
+ hub_cache_dir: str | None,
+) -> tuple[str, str, str, str]:
+ """
+ Parse `cfg_str` (hub URL), enforce `trust_remote_code`, and return
+ (repo_id, file_path, local_file, revision).
+ """
+ if not trust_remote_code:
+ raise RuntimeError(
+ f"Refusing to execute remote code from the Hub for '{cfg_str}'. "
+ "Executing hub env modules runs arbitrary Python code from third-party repositories. "
+ "If you trust this repo and understand the risks, call `make_env(..., trust_remote_code=True)` "
+            "and prefer pinning to a specific revision: 'user/repo@<revision>:env.py'."
+ )
+
+ repo_id, revision, file_path = _parse_hub_url(cfg_str)
+
+ try:
+ local_file = hf_hub_download(
+ repo_id=repo_id, filename=file_path, revision=revision, cache_dir=hub_cache_dir
+ )
+ except Exception as e:
+ # fallback to snapshot download
+ snapshot_dir = snapshot_download(repo_id=repo_id, revision=revision, cache_dir=hub_cache_dir)
+ local_file = os.path.join(snapshot_dir, file_path)
+ if not os.path.exists(local_file):
+ raise FileNotFoundError(
+ f"Could not find {file_path} in repository {repo_id}@{revision or 'main'}"
+ ) from e
+
+ return repo_id, file_path, local_file, revision
+
+
+def _import_hub_module(local_file: str, repo_id: str) -> Any:
+ """
+ Import the downloaded file as a module and surface helpful import error messages.
+ """
+ module_name = f"hub_env_{repo_id.replace('/', '_')}"
+ try:
+ module = _load_module_from_path(local_file, module_name=module_name)
+ except ModuleNotFoundError as e:
+ missing = getattr(e, "name", None) or str(e)
+ raise ModuleNotFoundError(
+ f"Hub env '{repo_id}:{os.path.basename(local_file)}' failed to import because the dependency "
+ f"'{missing}' is not installed locally.\n\n"
+ ) from e
+ except ImportError as e:
+ raise ImportError(
+ f"Failed to load hub env module '{repo_id}:{os.path.basename(local_file)}'. Import error: {e}\n\n"
+ ) from e
+ return module
+
+
+def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
+ """
+ Ensure module exposes make_env and call it.
+ """
+ if not hasattr(module, "make_env"):
+ raise AttributeError(
+ f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
+ )
+ entry_fn = module.make_env
+ return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
+
+
+def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
+ """
+ Normalize possible return types from hub `make_env` into the mapping:
+ { suite_name: { task_id: vector_env } }
+ Accepts:
+ - dict (assumed already correct)
+ - gym.vector.VectorEnv
+ - gym.Env (will be wrapped into SyncVectorEnv)
+ """
+ if isinstance(result, dict):
+ return result
+
+ # VectorEnv: use its spec.id if available
+ if isinstance(result, gym.vector.VectorEnv):
+ suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
+ return {suite_name: {0: result}}
+
+ # Single Env: wrap into SyncVectorEnv
+ if isinstance(result, gym.Env):
+ vec = gym.vector.SyncVectorEnv([lambda: result])
+ suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
+ return {suite_name: {0: vec}}
+
+ raise ValueError(
+ "Hub `make_env` must return either a mapping {suite: {task_id: vec_env}}, "
+ "a gym.vector.VectorEnv, or a single gym.Env."
+ )
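+
+
+# Illustrative sketch (not part of this patch) of the `env.py` a Hub repo is expected to provide:
+#
+#     import gymnasium as gym
+#
+#     def make_env(n_envs: int, use_async_envs: bool):
+#         # Any of the accepted return types works; here a single VectorEnv.
+#         return gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(n_envs)])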
diff --git a/src/lerobot/model/kinematics.py b/src/lerobot/model/kinematics.py
index f059b9790..95d3b235c 100644
--- a/src/lerobot/model/kinematics.py
+++ b/src/lerobot/model/kinematics.py
@@ -22,18 +22,18 @@ class RobotKinematics:
self,
urdf_path: str,
target_frame_name: str = "gripper_frame_link",
- joint_names: list[str] = None,
+ joint_names: list[str] | None = None,
):
"""
Initialize placo-based kinematics solver.
Args:
- urdf_path: Path to the robot URDF file
- target_frame_name: Name of the end-effector frame in the URDF
- joint_names: List of joint names to use for the kinematics solver
+ urdf_path (str): Path to the robot URDF file
+ target_frame_name (str): Name of the end-effector frame in the URDF
+ joint_names (list[str] | None): List of joint names to use for the kinematics solver
"""
try:
- import placo
+ import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
except ImportError as e:
raise ImportError(
"placo is required for RobotKinematics. "
@@ -52,7 +52,7 @@ class RobotKinematics:
# Initialize frame task for IK
self.tip_frame = self.solver.add_frame_task(self.target_frame_name, np.eye(4))
- def forward_kinematics(self, joint_pos_deg):
+ def forward_kinematics(self, joint_pos_deg: np.ndarray) -> np.ndarray:
"""
Compute forward kinematics for given joint configuration given the target frame name in the constructor.
@@ -77,8 +77,12 @@ class RobotKinematics:
return self.robot.get_T_world_frame(self.target_frame_name)
def inverse_kinematics(
- self, current_joint_pos, desired_ee_pose, position_weight=1.0, orientation_weight=0.01
- ):
+ self,
+ current_joint_pos: np.ndarray,
+ desired_ee_pose: np.ndarray,
+ position_weight: float = 1.0,
+ orientation_weight: float = 0.01,
+ ) -> np.ndarray:
"""
Compute inverse kinematics using placo solver.
diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py
index e1d4e0963..01bfcf544 100644
--- a/src/lerobot/motors/dynamixel/dynamixel.py
+++ b/src/lerobot/motors/dynamixel/dynamixel.py
@@ -60,7 +60,7 @@ class OperatingMode(Enum):
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
- # conveyer systems or a system that requires an additional reduction gear. Note that Max Position
+ # conveyor systems or a system that requires an additional reduction gear. Note that Max Position
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
EXTENDED_POSITION = 4
diff --git a/src/lerobot/motors/feetech/tables.py b/src/lerobot/motors/feetech/tables.py
index 48814957f..91e844a72 100644
--- a/src/lerobot/motors/feetech/tables.py
+++ b/src/lerobot/motors/feetech/tables.py
@@ -206,8 +206,12 @@ MODEL_BAUDRATE_TABLE = {
# Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 11,
+ "Goal_Position": 15,
"Goal_Velocity": 15,
+ "Goal_Speed": 15,
+ "Present_Position": 15,
"Present_Velocity": 15,
+ "Present_Speed": 15,
}
MODEL_ENCODING_TABLE = {
diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py
index 49f1e0f95..4cdc89ea9 100644
--- a/src/lerobot/policies/__init__.py
+++ b/src/lerobot/policies/__init__.py
@@ -14,6 +14,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
+from .groot.configuration_groot import GrootConfig as GrootConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
@@ -29,4 +30,5 @@ __all__ = [
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
+ "GrootConfig",
]
diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py
index 4d2890ba6..b7cbcd061 100644
--- a/src/lerobot/policies/act/modeling_act.py
+++ b/src/lerobot/policies/act/modeling_act.py
@@ -626,8 +626,8 @@ class ACTDecoderLayer(nn.Module):
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
cross-attending with.
- decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
- encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
+ encoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
+ decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
Returns:
(DS, B, C) tensor of decoder output features.
"""
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index cfb550ab2..eb6266757 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -37,6 +38,7 @@ from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
+from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
@@ -101,6 +103,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
+ elif name == "groot":
+ from lerobot.policies.groot.modeling_groot import GrootPolicy
+
+ return GrootPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -142,6 +148,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
+ elif policy_type == "groot":
+ return GrootConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -199,6 +207,27 @@ def make_pre_post_processors(
policy configuration type.
"""
if pretrained_path:
+ # TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
+ if isinstance(policy_cfg, GrootConfig):
+ # GROOT handles normalization in groot_pack_inputs_v3 step
+ # Need to override both stats AND normalize_min_max since saved config might be empty
+ preprocessor_overrides = {}
+ postprocessor_overrides = {}
+ preprocessor_overrides["groot_pack_inputs_v3"] = {
+ "stats": kwargs.get("dataset_stats"),
+ "normalize_min_max": True,
+ }
+
+ # Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
+ env_action_dim = policy_cfg.output_features["action"].shape[0]
+ postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
+ "stats": kwargs.get("dataset_stats"),
+ "normalize_min_max": True,
+ "env_action_dim": env_action_dim,
+ }
+ kwargs["preprocessor_overrides"] = preprocessor_overrides
+ kwargs["postprocessor_overrides"] = postprocessor_overrides
+
return (
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
@@ -293,6 +322,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
+ elif isinstance(policy_cfg, GrootConfig):
+ from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
+
+ processors = make_groot_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
@@ -303,6 +340,7 @@ def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
+ rename_map: dict[str, str] | None = None,
) -> PreTrainedPolicy:
"""
Instantiate a policy model.
@@ -319,6 +357,8 @@ def make_policy(
statistics for normalization layers.
env_cfg: Environment configuration used to infer feature shapes and types.
One of `ds_meta` or `env_cfg` must be provided.
+ rename_map: Optional mapping of dataset or environment feature keys to match
+ expected policy feature names (e.g., `"left"` → `"camera1"`).
Returns:
An instantiated and device-placed policy model.
@@ -360,8 +400,10 @@ def make_policy(
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
- cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
- cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
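+    # Only derive features from the environment when the config does not already define them
+    # (e.g. a pretrained config may already ship its own input/output features).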
+ if not cfg.output_features:
+ cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+ if not cfg.input_features:
+ cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
if cfg.pretrained_path:
@@ -378,4 +420,8 @@ def make_policy(
# policy = torch.compile(policy, mode="reduce-overhead")
+ if not rename_map:
+ validate_visual_features_consistency(cfg, features)
+ # TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
+
return policy
diff --git a/src/lerobot/policies/groot/README.md b/src/lerobot/policies/groot/README.md
new file mode 120000
index 000000000..ff4937f5c
--- /dev/null
+++ b/src/lerobot/policies/groot/README.md
@@ -0,0 +1 @@
+../../../../docs/source/policy_groot_README.md
\ No newline at end of file
diff --git a/src/lerobot/robots/stretch3/__init__.py b/src/lerobot/policies/groot/__init__.py
similarity index 65%
rename from src/lerobot/robots/stretch3/__init__.py
rename to src/lerobot/policies/groot/__init__.py
index b3070bbd6..c8933ff56 100644
--- a/src/lerobot/robots/stretch3/__init__.py
+++ b/src/lerobot/policies/groot/__init__.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
-# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Nvidia and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,5 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .configuration_stretch3 import Stretch3RobotConfig
-from .robot_stretch3 import Stretch3Robot
+from .configuration_groot import GrootConfig
+from .modeling_groot import GrootPolicy
+from .processor_groot import make_groot_pre_post_processors
+
+__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
diff --git a/src/lerobot/robots/viperx/__init__.py b/src/lerobot/policies/groot/action_head/__init__.py
similarity index 70%
rename from src/lerobot/robots/viperx/__init__.py
rename to src/lerobot/policies/groot/action_head/__init__.py
index bfba07fc7..3159bfe65 100644
--- a/src/lerobot/robots/viperx/__init__.py
+++ b/src/lerobot/policies/groot/action_head/__init__.py
@@ -1,18 +1,14 @@
-#!/usr/bin/env python
-
-# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from .config_viperx import ViperXConfig
-from .viperx import ViperX
diff --git a/src/lerobot/policies/groot/action_head/action_encoder.py b/src/lerobot/policies/groot/action_head/action_encoder.py
new file mode 100644
index 000000000..c6fa0a779
--- /dev/null
+++ b/src/lerobot/policies/groot/action_head/action_encoder.py
@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+
+def swish(x):
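+    """Swish/SiLU activation: x * sigmoid(x)."""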
+ return x * torch.sigmoid(x)
+
+
+class SinusoidalPositionalEncoding(nn.Module):
+ """
+ Produces a sinusoidal encoding of shape (B, T, w)
+ given timesteps of shape (B, T).
+ """
+
+ def __init__(self, embedding_dim):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+
+ def forward(self, timesteps):
+ # timesteps: shape (B, T)
+        # Compute sin/cos frequencies across the embedding dimension (half sin, half cos)
+ timesteps = timesteps.float() # ensure float
+
+ b, t = timesteps.shape
+ device = timesteps.device
+
+ half_dim = self.embedding_dim // 2
+ # typical log space frequencies for sinusoidal encoding
+ exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
+ torch.log(torch.tensor(10000.0)) / half_dim
+ )
+ # Expand timesteps to (B, T, 1) then multiply
+ freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
+
+ sin = torch.sin(freqs)
+ cos = torch.cos(freqs)
+ enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
+
+ return enc
diff --git a/src/lerobot/policies/groot/action_head/cross_attention_dit.py b/src/lerobot/policies/groot/action_head/cross_attention_dit.py
new file mode 100755
index 000000000..40f7ba603
--- /dev/null
+++ b/src/lerobot/policies/groot/action_head/cross_attention_dit.py
@@ -0,0 +1,370 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn.functional as F # noqa: N812
+from diffusers import ConfigMixin, ModelMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.embeddings import (
+ SinusoidalPositionalEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from torch import nn
+
+
+class TimestepEncoder(nn.Module):
+ def __init__(self, embedding_dim, compute_dtype=torch.float32):
+ super().__init__()
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(self, timesteps):
+ dtype = next(self.parameters()).dtype
+ timesteps_proj = self.time_proj(timesteps).to(dtype)
+ timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
+ return timesteps_emb
+
+
+class AdaLayerNorm(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-5,
+ chunk_dim: int = 0,
+ ):
+ super().__init__()
+ self.chunk_dim = chunk_dim
+ output_dim = embedding_dim * 2
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, output_dim)
+ self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ temb: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ temb = self.linear(self.silu(temb))
+ scale, shift = temb.chunk(2, dim=1)
+ x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
+ return x
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: int | None = None,
+ activation_fn: str = "geglu",
+ attention_bias: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: str | None = None,
+ num_positional_embeddings: int | None = None,
+ ff_inner_dim: int | None = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.dropout = dropout
+ self.cross_attention_dim = cross_attention_dim
+ self.activation_fn = activation_fn
+ self.attention_bias = attention_bias
+ self.norm_elementwise_affine = norm_elementwise_affine
+ self.positional_embeddings = positional_embeddings
+ self.num_positional_embeddings = num_positional_embeddings
+ self.norm_type = norm_type
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if norm_type == "ada_norm":
+ self.norm1 = AdaLayerNorm(dim)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ )
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+ if final_dropout:
+ self.final_dropout = nn.Dropout(dropout)
+ else:
+ self.final_dropout = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ encoder_hidden_states: torch.Tensor | None = None,
+ encoder_attention_mask: torch.Tensor | None = None,
+ temb: torch.LongTensor | None = None,
+ ) -> torch.Tensor:
+ # 0. Self-Attention
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm1(hidden_states, temb)
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ # encoder_attention_mask=encoder_attention_mask,
+ )
+ if self.final_dropout:
+ attn_output = self.final_dropout(attn_output)
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+ return hidden_states
+
+
+class DiT(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 8,
+ attention_head_dim: int = 64,
+ output_dim: int = 26,
+ num_layers: int = 12,
+ dropout: float = 0.1,
+ attention_bias: bool = True,
+ activation_fn: str = "gelu-approximate",
+ num_embeds_ada_norm: int | None = 1000,
+ upcast_attention: bool = False,
+ norm_type: str = "ada_norm",
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-5,
+ max_num_positional_embeddings: int = 512,
+ compute_dtype=torch.float32,
+ final_dropout: bool = True,
+ positional_embeddings: str | None = "sinusoidal",
+ interleave_self_attention=False,
+ cross_attention_dim: int | None = None,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.gradient_checkpointing = False
+
+ # Timestep encoder
+ self.timestep_encoder = TimestepEncoder(
+ embedding_dim=self.inner_dim, compute_dtype=self.config.compute_dtype
+ )
+
+ all_blocks = []
+ for idx in range(self.config.num_layers):
+ use_self_attn = idx % 2 == 1 and interleave_self_attention
+ curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
+
+ all_blocks += [
+ BasicTransformerBlock(
+ self.inner_dim,
+ self.config.num_attention_heads,
+ self.config.attention_head_dim,
+ dropout=self.config.dropout,
+ activation_fn=self.config.activation_fn,
+ attention_bias=self.config.attention_bias,
+ upcast_attention=self.config.upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
+ norm_eps=self.config.norm_eps,
+ positional_embeddings=positional_embeddings,
+ num_positional_embeddings=self.config.max_num_positional_embeddings,
+ final_dropout=final_dropout,
+ cross_attention_dim=curr_cross_attention_dim,
+ )
+ ]
+ self.transformer_blocks = nn.ModuleList(all_blocks)
+
+ # Output blocks
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+ self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
+ print(
+ "Total number of DiT parameters: ",
+ sum(p.numel() for p in self.parameters() if p.requires_grad),
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor, # Shape: (B, T, D)
+ encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
+ timestep: torch.LongTensor | None = None,
+ encoder_attention_mask: torch.Tensor | None = None,
+ return_all_hidden_states: bool = False,
+ ):
+ # Encode timesteps
+ temb = self.timestep_encoder(timestep)
+
+ # Process through transformer blocks - single pass through the blocks
+ hidden_states = hidden_states.contiguous()
+ encoder_hidden_states = encoder_hidden_states.contiguous()
+
+ all_hidden_states = [hidden_states]
+
+ # Process through transformer blocks
+ for idx, block in enumerate(self.transformer_blocks):
+ if idx % 2 == 1 and self.config.interleave_self_attention:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ temb=temb,
+ )
+ else:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=None,
+ temb=temb,
+ )
+ all_hidden_states.append(hidden_states)
+
+ # Output processing
+ conditioning = temb
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ if return_all_hidden_states:
+ return self.proj_out_2(hidden_states), all_hidden_states
+ else:
+ return self.proj_out_2(hidden_states)
+
+
+class SelfAttentionTransformer(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 8,
+ attention_head_dim: int = 64,
+ output_dim: int = 26,
+ num_layers: int = 12,
+ dropout: float = 0.1,
+ attention_bias: bool = True,
+ activation_fn: str = "gelu-approximate",
+ num_embeds_ada_norm: int | None = 1000,
+ upcast_attention: bool = False,
+ max_num_positional_embeddings: int = 512,
+ compute_dtype=torch.float32,
+ final_dropout: bool = True,
+ positional_embeddings: str | None = "sinusoidal",
+ interleave_self_attention=False,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.gradient_checkpointing = False
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ self.inner_dim,
+ self.config.num_attention_heads,
+ self.config.attention_head_dim,
+ dropout=self.config.dropout,
+ activation_fn=self.config.activation_fn,
+ attention_bias=self.config.attention_bias,
+ upcast_attention=self.config.upcast_attention,
+ positional_embeddings=positional_embeddings,
+ num_positional_embeddings=self.config.max_num_positional_embeddings,
+ final_dropout=final_dropout,
+ )
+ for _ in range(self.config.num_layers)
+ ]
+ )
+ print(
+ "Total number of SelfAttentionTransformer parameters: ",
+ sum(p.numel() for p in self.parameters() if p.requires_grad),
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor, # Shape: (B, T, D)
+ return_all_hidden_states: bool = False,
+ ):
+ # Process through transformer blocks - single pass through the blocks
+ hidden_states = hidden_states.contiguous()
+ all_hidden_states = [hidden_states]
+
+ # Process through transformer blocks
+ for _idx, block in enumerate(self.transformer_blocks):
+ hidden_states = block(hidden_states)
+ all_hidden_states.append(hidden_states)
+
+ if return_all_hidden_states:
+ return hidden_states, all_hidden_states
+ else:
+ return hidden_states
diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
new file mode 100644
index 000000000..bfc456ba0
--- /dev/null
+++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
@@ -0,0 +1,406 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn.functional as F # noqa: N812
+from torch import nn
+from torch.distributions import Beta
+
+from lerobot.utils.import_utils import _transformers_available
+
+# Conditional import for type checking and lazy loading
+if TYPE_CHECKING or _transformers_available:
+ from transformers import PretrainedConfig
+ from transformers.feature_extraction_utils import BatchFeature
+else:
+ PretrainedConfig = object
+ BatchFeature = None
+
+from lerobot.policies.groot.action_head.action_encoder import (
+ SinusoidalPositionalEncoding,
+ swish,
+)
+
+from .cross_attention_dit import DiT, SelfAttentionTransformer
+
+
+class CategorySpecificLinear(nn.Module):
+ def __init__(self, num_categories, input_dim, hidden_dim):
+ super().__init__()
+ self.num_categories = num_categories
+ # For each category, we have separate weights and biases.
+ self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
+ self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
+
+ def forward(self, x, cat_ids):
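+        # x: (B, T, input_dim), cat_ids: (B,); each sample uses its category's weight matrix and bias.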
+ selected_w = self.W[cat_ids]
+ selected_b = self.b[cat_ids]
+ return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
+
+
+class CategorySpecificMLP(nn.Module):
+ def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
+ super().__init__()
+ self.num_categories = num_categories
+ self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
+ self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
+
+ def forward(self, x, cat_ids):
+ hidden = F.relu(self.layer1(x, cat_ids))
+ return self.layer2(hidden, cat_ids)
+
+
+class MultiEmbodimentActionEncoder(nn.Module):
+ def __init__(self, action_dim, hidden_size, num_embodiments):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.num_embodiments = num_embodiments
+
+ # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
+ self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
+ self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
+ self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
+ self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
+
+ def forward(self, actions, timesteps, cat_ids):
+ """
+ actions: shape (B, T, action_dim)
+ timesteps: shape (B,) -- a single scalar per batch item
+ cat_ids: shape (B,)
+ returns: shape (B, T, hidden_size)
+ """
+ b, t, _ = actions.shape
+
+ # 1) Expand each batch's single scalar time 'tau' across all T steps
+ # so that shape => (B, T)
+ # e.g. if timesteps is (B,), replicate across T
+ if timesteps.dim() == 1 and timesteps.shape[0] == b:
+ # shape (B,) => (B,T)
+ timesteps = timesteps.unsqueeze(1).expand(-1, t)
+ else:
+ raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
+
+ # 2) Standard action MLP step for shape => (B, T, w)
+ a_emb = self.W1(actions, cat_ids)
+
+ # 3) Get the sinusoidal encoding (B, T, w)
+ tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
+
+ # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
+ x = torch.cat([a_emb, tau_emb], dim=-1)
+ x = swish(self.W2(x, cat_ids))
+
+ # 5) Finally W3 => (B, T, w)
+ x = self.W3(x, cat_ids)
+ return x
+
+
+@dataclass
+class FlowmatchingActionHeadConfig(PretrainedConfig):
+ """NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
+
+ add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
+ model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
+ diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
+ input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
+ backbone_embedding_dim: int = field(
+ default=1536, metadata={"help": "Backbone embedding channel dimension."}
+ )
+
+ hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
+ max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
+ action_dim: int = field(default=None, metadata={"help": "Action dimension."})
+ action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
+    noise_beta_alpha: float = field(
+        default=1.5, metadata={"help": "Alpha of the Beta distribution used to sample flow-matching timesteps."}
+    )
+    noise_beta_beta: float = field(
+        default=1.0, metadata={"help": "Beta of the Beta distribution used to sample flow-matching timesteps."}
+    )
+    noise_s: float = field(
+        default=0.999, metadata={"help": "Scale s applied to the sampled Beta value: t = (s - u) / s."}
+    )
+ num_timestep_buckets: int = field(
+ default=1000, metadata={"help": "Number of timestep discretization buckets."}
+ )
+ num_inference_timesteps: int = field(
+ default=None,
+ metadata={"help": "Number of inference steps for noise diffusion."},
+ )
+ max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
+ tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
+ tune_diffusion_model: bool = field(
+ default=True, metadata={"help": "Whether to tune the diffusion model."}
+ )
+ load_pretrained_det_decode_layer_path: str = field(
+ default=None, metadata={"help": "Path to pretrained detection model."}
+ )
+ detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
+
+ freeze_decode_layer: bool = field(default=False)
+ expand_batch: int = field(default=None)
+ use_vlln: bool = field(default=True)
+
+ vl_self_attention_cfg: dict = field(default=None)
+ num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+
+class FlowmatchingActionHead(nn.Module):
+ config_class = FlowmatchingActionHeadConfig
+ supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ config: FlowmatchingActionHeadConfig,
+ ):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.input_embedding_dim = config.input_embedding_dim
+
+ self.model = DiT(**config.diffusion_model_cfg)
+ self.action_dim = config.action_dim
+ self.action_horizon = config.action_horizon
+ self.num_inference_timesteps = config.num_inference_timesteps
+
+ self.state_encoder = CategorySpecificMLP(
+ num_categories=config.max_num_embodiments,
+ input_dim=config.max_state_dim,
+ hidden_dim=self.hidden_size,
+ output_dim=self.input_embedding_dim,
+ )
+ self.action_encoder = MultiEmbodimentActionEncoder(
+ action_dim=config.action_dim,
+ hidden_size=self.input_embedding_dim,
+ num_embodiments=config.max_num_embodiments,
+ )
+ self.action_decoder = CategorySpecificMLP(
+ num_categories=config.max_num_embodiments,
+ input_dim=self.hidden_size,
+ hidden_dim=self.hidden_size,
+ output_dim=self.action_dim,
+ )
+ self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
+ nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
+
+ self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
+ self.vl_self_attention = (
+ SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
+ )
+
+ if config.add_pos_embed:
+ self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
+ nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
+
+ self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
+ self.num_timestep_buckets = config.num_timestep_buckets
+ self.config = config
+ self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
+
+ def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
+ self.tune_projector = tune_projector
+ self.tune_diffusion_model = tune_diffusion_model
+ for p in self.parameters():
+ p.requires_grad = True
+ if not tune_projector:
+ self.state_encoder.requires_grad_(False)
+ self.action_encoder.requires_grad_(False)
+ self.action_decoder.requires_grad_(False)
+ if self.config.add_pos_embed:
+ self.position_embedding.requires_grad_(False)
+ if not tune_diffusion_model:
+ self.model.requires_grad_(False)
+ print(f"Tune action head projector: {self.tune_projector}")
+ print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
+ # Check if any parameters are still trainable. If not, print a warning.
+ if not tune_projector and not tune_diffusion_model:
+ for name, p in self.named_parameters():
+ if p.requires_grad:
+ print(f"Action head trainable parameter: {name}")
+ if not any(p.requires_grad for p in self.parameters()):
+ print("Warning: No action head trainable parameters found.")
+
+ def set_frozen_modules_to_eval_mode(self):
+ """
+        Hugging Face will call model.train() at each training_step. To ensure
+ the expected behaviors for modules like dropout, batchnorm, etc., we
+ need to call model.eval() for the frozen modules.
+ """
+ if self.training:
+ if not self.tune_projector:
+ self.state_encoder.eval()
+ self.action_encoder.eval()
+ self.action_decoder.eval()
+ if self.config.add_pos_embed:
+ self.position_embedding.eval()
+ if not self.tune_diffusion_model:
+ self.model.eval()
+
+ def sample_time(self, batch_size, device, dtype):
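+        # Draw u ~ Beta(noise_beta_alpha, noise_beta_beta) and map it to t = (s - u) / s,
+        # the flow-matching time used to interpolate between noise (t = 0) and data (t = 1).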
+ sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
+ return (self.config.noise_s - sample) / self.config.noise_s
+
+ def prepare_input(self, batch: dict) -> BatchFeature:
+ return BatchFeature(data=batch)
+
+ def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
+ backbone_features = backbone_output["backbone_features"]
+ backbone_features = self.vlln(backbone_features)
+ backbone_features = self.vl_self_attention(backbone_features)
+ backbone_output["backbone_features"] = backbone_features
+ return backbone_output
+
+ def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
+ # Set frozen modules to eval
+ self.set_frozen_modules_to_eval_mode()
+
+ backbone_output = self.process_backbone_output(backbone_output)
+
+ if self.config.expand_batch is not None:
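+            # Repeat every backbone and action tensor `expand_batch` times along the batch
+            # dimension (repeat factor 1 on all remaining dimensions).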
+ for k, v in backbone_output.items():
+ ndim = len(v.shape)
+ factors = [self.config.expand_batch]
+ while len(factors) < ndim:
+ factors.append(1)
+ factors = tuple(factors)
+ expanded = v.repeat(*factors)
+ backbone_output[k] = expanded
+
+ for k, v in action_input.items():
+ ndim = len(v.shape)
+ factors = [self.config.expand_batch]
+ while len(factors) < ndim:
+ factors.append(1)
+ factors = tuple(factors)
+ expanded = v.repeat(*factors)
+ action_input[k] = expanded
+
+ # Get vision and language embeddings.
+ vl_embs = backbone_output.backbone_features
+ device = vl_embs.device
+
+ # Get embodiment ID.
+ embodiment_id = action_input.embodiment_id
+
+ # Embed state.
+ state_features = self.state_encoder(action_input.state, embodiment_id)
+
+ # Embed noised action trajectory.
+ actions = action_input.action
+ noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
+ t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
+ t = t[:, None, None] # shape (B,1,1) for broadcast
+
+ noisy_trajectory = (1 - t) * noise + t * actions
+ velocity = actions - noise
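+        # Flow matching: the interpolation path x_t = (1 - t) * noise + t * actions has the
+        # constant velocity d x_t / dt = actions - noise, which the network is trained to predict.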
+
+ # Convert (continuous) t -> discrete if needed
+ t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
+ action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
+
+ # Maybe add position embedding.
+ if self.config.add_pos_embed:
+ pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
+ pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
+ action_features = action_features + pos_embs
+
+        # Join state, future-token and action embeddings along the sequence dimension;
+        # the vision-language embeddings are consumed separately via cross-attention.
+ future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
+ sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
+
+ vl_attn_mask = backbone_output.backbone_attention_mask
+
+ model_output = self.model(
+ hidden_states=sa_embs,
+ encoder_hidden_states=vl_embs,
+ encoder_attention_mask=vl_attn_mask,
+ timestep=t_discretized,
+ return_all_hidden_states=False, # NOTE (YL): not using flare now
+ )
+ pred = self.action_decoder(model_output, embodiment_id)
+        # Keep only the trailing action tokens of the decoded sequence
+        # (the input sequence was [state | future tokens | actions]).
+        pred_actions = pred[:, -actions.shape[1] :]
+
+        # Masked MSE between predicted and target velocity, averaged over the valid action entries.
+        action_mask = action_input.action_mask
+        loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
+        loss = loss.sum() / action_mask.sum()
+ output_dict = {
+ "loss": loss,
+ }
+ return BatchFeature(data=output_dict)
+
+ @torch.no_grad()
+ def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
+ backbone_output = self.process_backbone_output(backbone_output)
+
+ # Get vision and language embeddings.
+ vl_embs = backbone_output.backbone_features
+ embodiment_id = action_input.embodiment_id
+
+ # Embed state.
+ state_features = self.state_encoder(action_input.state, embodiment_id)
+
+ # Set initial actions as the sampled noise.
+ batch_size = vl_embs.shape[0]
+ device = vl_embs.device
+ actions = torch.randn(
+ size=(batch_size, self.config.action_horizon, self.config.action_dim),
+ dtype=vl_embs.dtype,
+ device=device,
+ )
+
+ num_steps = self.num_inference_timesteps
+ dt = 1.0 / num_steps
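+        # Forward-Euler integration of the learned velocity field from t = 0 (pure noise)
+        # to t = 1 (actions): x_{t+dt} = x_t + dt * v_theta(x_t, t).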
+
+ # Run denoising steps.
+ for t in range(num_steps):
+ t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
+ t_discretized = int(t_cont * self.num_timestep_buckets)
+
+ # Embed noised action trajectory.
+ timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
+ action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
+ # Maybe add position embedding.
+ if self.config.add_pos_embed:
+ pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
+ pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
+ action_features = action_features + pos_embs
+
+            # Join state, future-token and action embeddings along the sequence dimension;
+            # the vision-language embeddings are consumed separately via cross-attention.
+ future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
+ sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
+
+ # Run model forward.
+ model_output = self.model(
+ hidden_states=sa_embs,
+ encoder_hidden_states=vl_embs,
+ timestep=timesteps_tensor,
+ )
+ pred = self.action_decoder(model_output, embodiment_id)
+
+ pred_velocity = pred[:, -self.action_horizon :]
+
+ # Update actions using euler integration.
+ actions = actions + dt * pred_velocity
+ return BatchFeature(data={"action_pred": actions})
+
+ @property
+ def device(self):
+ return next(iter(self.parameters())).device
+
+ @property
+ def dtype(self):
+ return next(iter(self.parameters())).dtype
diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py
new file mode 100644
index 000000000..8002c69ea
--- /dev/null
+++ b/src/lerobot/policies/groot/configuration_groot.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python
+
+# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.optim.optimizers import AdamWConfig
+from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+
+
+@PreTrainedConfig.register_subclass("groot")
+@dataclass
+class GrootConfig(PreTrainedConfig):
+ """Configuration for Groot policy wrapper."""
+
+ # Basic policy settings
+ n_obs_steps: int = 1
+ chunk_size: int = 50
+ n_action_steps: int = 50
+
+ # Dimension settings (must match pretrained GR00T model expectations)
+ # Maximum state dimension. Shorter states will be zero-padded.
+ max_state_dim: int = 64
+
+ # Maximum action dimension. Shorter actions will be zero-padded.
+ max_action_dim: int = 32
+
+    # Normalization modes per feature type: images pass through unchanged,
+    # states and actions are mean/std normalized. Adjust as needed.
+ normalization_mapping: dict[str, NormalizationMode] = field(
+ default_factory=lambda: {
+ "VISUAL": NormalizationMode.IDENTITY,
+ "STATE": NormalizationMode.MEAN_STD,
+ "ACTION": NormalizationMode.MEAN_STD,
+ }
+ )
+
+ # Image preprocessing (adjust to match Groot's expected input)
+ image_size: tuple[int, int] = (224, 224)
+
+ # Groot-specific model parameters (from groot_finetune_script.py)
+
+ # Path or HuggingFace model ID for the base Groot model
+ base_model_path: str = "nvidia/GR00T-N1.5-3B"
+
+ # HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
+ tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
+
+ # Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
+ embodiment_tag: str = "new_embodiment"
+
+ # Fine-tuning control arguments
+
+ # Whether to fine-tune the llm backbone
+ tune_llm: bool = False
+
+ # Whether to fine-tune the vision tower
+ tune_visual: bool = False
+
+ # Whether to fine-tune the projector
+ tune_projector: bool = True
+
+ # Whether to fine-tune the diffusion model
+ tune_diffusion_model: bool = True
+
+ # LoRA parameters (from groot_finetune_script.py)
+    # Rank for the LoRA adapter. If 0, LoRA is not used.
+    lora_rank: int = 0
+
+    # Alpha value for the LoRA adapter
+    lora_alpha: int = 16
+
+    # Dropout rate for the LoRA adapter
+    lora_dropout: float = 0.1
+
+    # Whether to apply LoRA to the full model
+ lora_full_model: bool = False
+
+ # Training parameters (matching groot_finetune_script.py)
+ optimizer_lr: float = 1e-4
+ optimizer_betas: tuple[float, float] = (0.95, 0.999)
+ optimizer_eps: float = 1e-8
+ optimizer_weight_decay: float = 1e-5
+ warmup_ratio: float = 0.05
+ use_bf16: bool = True
+
+ # Dataset parameters
+ # Video backend to use for training ('decord' or 'torchvision_av')
+ video_backend: str = "decord"
+
+ # Whether to balance dataset weights in mixture datasets
+ balance_dataset_weights: bool = True
+
+ # Whether to sample trajectories weighted by their length
+ balance_trajectory_weights: bool = True
+
+ # Optional dataset paths for delegating training to Isaac-GR00T runner
+ dataset_paths: list[str] | None = None
+ output_dir: str = "./tmp/gr00t"
+ save_steps: int = 1000
+ max_steps: int = 10000
+ batch_size: int = 32
+ dataloader_num_workers: int = 8
+ report_to: str = "wandb"
+ resume: bool = False
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ if self.n_action_steps > self.chunk_size:
+ raise ValueError(
+ f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
+ )
+
+ # groot_repo_path is now optional since we ported the components
+ # No validation needed
+
+ def validate_features(self) -> None:
+ """Validate and set up input/output features for Groot."""
+ image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
+ if not image_features:
+ raise ValueError(
+ "Groot policy requires at least one visual input feature. "
+ "No features of type FeatureType.VISUAL found in input_features."
+ )
+
+ if "observation.state" not in self.input_features:
+ state_feature = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(self.max_state_dim,),
+ )
+ self.input_features["observation.state"] = state_feature
+ else:
+ state_shape = self.input_features["observation.state"].shape
+ state_dim = state_shape[0] if state_shape else 0
+ if state_dim > self.max_state_dim:
+ raise ValueError(
+ f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
+ f"Either reduce state dimension or increase max_state_dim in config."
+ )
+
+ if "action" not in self.output_features:
+ action_feature = PolicyFeature(
+ type=FeatureType.ACTION,
+ shape=(self.max_action_dim,),
+ )
+ self.output_features["action"] = action_feature
+ else:
+ action_shape = self.output_features["action"].shape
+ action_dim = action_shape[0] if action_shape else 0
+ if action_dim > self.max_action_dim:
+ raise ValueError(
+ f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
+ f"Either reduce action dimension or increase max_action_dim in config."
+ )
+
+ def get_optimizer_preset(self) -> AdamWConfig:
+ """Return optimizer configuration."""
+ return AdamWConfig(
+ lr=self.optimizer_lr,
+ betas=self.optimizer_betas,
+ eps=self.optimizer_eps,
+ weight_decay=self.optimizer_weight_decay,
+ )
+
+ def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
+ """Return scheduler configuration."""
+ return CosineDecayWithWarmupSchedulerConfig(
+ num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
+ num_decay_steps=10000, # Adjust based on training steps
+ peak_lr=self.optimizer_lr,
+ decay_lr=self.optimizer_lr * 0.1,
+ )
+
+ @property
+ def observation_delta_indices(self) -> None:
+ """Return indices for delta observations (None for Groot)."""
+ return None
+
+ @property
+ def action_delta_indices(self) -> list[int]:
+ """Return indices for delta actions."""
+ return list(range(min(self.chunk_size, 16)))
+
+ @property
+ def reward_delta_indices(self) -> None:
+ """Return indices for delta rewards (None for Groot)."""
+ return None
diff --git a/src/lerobot/policies/groot/eagle2_hg_model/configuration_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/configuration_eagle2_5_vl.py
new file mode 100755
index 000000000..526b4f7a2
--- /dev/null
+++ b/src/lerobot/policies/groot/eagle2_hg_model/configuration_eagle2_5_vl.py
@@ -0,0 +1,135 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
+from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class Eagle25VLConfig(PretrainedConfig):
+ model_type = "eagle_2_5_vl"
+ is_composition = True
+ sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ use_backbone_lora=0,
+ use_llm_lora=0,
+ pad2square=False,
+ select_layer=-4,
+ force_image_size=None,
+ downsample_ratio=0.5,
+ template=None,
+ dynamic_image_size=False,
+ use_thumbnail=False,
+ loss_version="v1",
+ min_dynamic_tiles=1,
+ max_dynamic_tiles=6,
+ mlp_checkpoint=False,
+ initializer_range=0.02,
+ _attn_implementation="flash_attention_2",
+ _attn_implementation_autoset=False,
+ llm_config=None,
+ image_token_index=None,
+ use_pixel_shuffle=True,
+ mlp_connector_layers=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if vision_config is None:
+ vision_config = {"model_type": "siglip_vision_model"}
+ logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
+
+ if text_config is None:
+ text_config = {"architectures": ["Qwen2ForCausalLM"]}
+ logger.info(
+ "text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
+ )
+
+ if vision_config["model_type"] == "siglip_vision_model":
+ self.vision_config = SiglipVisionConfig(**vision_config)
+ else:
+ raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
+
+ if text_config["architectures"][0] == "LlamaForCausalLM":
+ self.text_config = LlamaConfig(**text_config)
+ elif text_config["architectures"][0] == "Qwen2ForCausalLM":
+ self.text_config = Qwen2Config(**text_config)
+ elif text_config["architectures"][0] == "Qwen3ForCausalLM":
+ self.text_config = Qwen3Config(**text_config)
+ else:
+ raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
+ self.use_backbone_lora = use_backbone_lora
+ self.use_llm_lora = use_llm_lora
+ self.mlp_checkpoint = mlp_checkpoint
+ self.pad2square = pad2square
+ self.select_layer = select_layer
+ self.force_image_size = force_image_size
+ self.downsample_ratio = downsample_ratio
+ self.template = template
+ self.dynamic_image_size = dynamic_image_size
+ self.use_thumbnail = use_thumbnail
+ self.loss_version = loss_version
+ self.initializer_range = initializer_range
+ self.min_dynamic_tiles = min_dynamic_tiles
+ self.max_dynamic_tiles = max_dynamic_tiles
+ self.tie_word_embeddings = self.text_config.tie_word_embeddings
+ self._attn_implementation = _attn_implementation
+ self._attn_implementation_autoset = _attn_implementation_autoset
+ self.image_token_index = image_token_index
+ self.use_pixel_shuffle = use_pixel_shuffle
+ self.mlp_connector_layers = mlp_connector_layers
+ logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
+ logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
+
+ def to_dict(self):
+ """
+        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["vision_config"] = self.vision_config.to_dict()
+ output["text_config"] = self.text_config.to_dict()
+ output["model_type"] = self.__class__.model_type
+ output["use_backbone_lora"] = self.use_backbone_lora
+ output["use_llm_lora"] = self.use_llm_lora
+ output["pad2square"] = self.pad2square
+ output["select_layer"] = self.select_layer
+ output["force_image_size"] = self.force_image_size
+ output["downsample_ratio"] = self.downsample_ratio
+ output["template"] = self.template
+ output["dynamic_image_size"] = self.dynamic_image_size
+ output["use_thumbnail"] = self.use_thumbnail
+ output["min_dynamic_tiles"] = self.min_dynamic_tiles
+ output["max_dynamic_tiles"] = self.max_dynamic_tiles
+ output["tie_word_embeddings"] = self.tie_word_embeddings
+ output["_attn_implementation"] = self._attn_implementation
+ output["_attn_implementation_autoset"] = self._attn_implementation_autoset
+ output["use_pixel_shuffle"] = self.use_pixel_shuffle
+ output["mlp_connector_layers"] = self.mlp_connector_layers
+ return output
diff --git a/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py b/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py
new file mode 100644
index 000000000..6b4f6d7ac
--- /dev/null
+++ b/src/lerobot/policies/groot/eagle2_hg_model/image_processing_eagle2_5_vl_fast.py
@@ -0,0 +1,504 @@
+# --------------------------------------------------------
+# NVIDIA
+# Copyright (c) 2025 NVIDIA
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+
+# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
+from typing import Optional
+
+from transformers.image_processing_utils import (
+ BatchFeature,
+ get_patch_output_size,
+)
+from transformers.image_processing_utils_fast import (
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ group_images_by_shape,
+ reorder_images,
+)
+from transformers.image_utils import (
+ IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
+ IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+ get_image_size,
+ make_flat_list_of_images,
+ validate_kwargs,
+)
+from transformers.processing_utils import Unpack
+from transformers.utils import (
+ TensorType,
+ add_start_docstrings,
+ is_torch_available,
+ is_torchvision_v2_available,
+)
+from transformers.video_utils import VideoInput
+
+if is_torch_available():
+ import torch
+if is_torchvision_v2_available():
+ from torchvision.transforms.v2 import functional as F # noqa: N812
+ from transformers.image_utils import pil_torch_interpolation_mapping
+else:
+ from torchvision.transforms import functional as F # noqa: N812
+
+
+def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
+ """Crop the given numpy array.
+
+ Args:
+ img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
+ left (int): The left coordinate of the crop box.
+ top (int): The top coordinate of the crop box.
+ right (int): The right coordinate of the crop box.
+ bottom (int): The bottom coordinate of the crop box.
+
+ Returns:
+ torch.Tensor: Cropped image.
+ """
+ if not isinstance(img, torch.Tensor):
+ raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
+
+ if img.ndim not in [2, 3]:
+ raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
+
+ img_height = img.shape[1]
+ img_width = img.shape[2]
+ if top < 0 or left < 0 or bottom > img_height or right > img_width:
+ raise ValueError("Crop coordinates out of bounds")
+
+ if top >= bottom or left >= right:
+ raise ValueError("Invalid crop coordinates")
+
+ return img[:, top:bottom, left:right]
+
+
+class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+ max_dynamic_tiles: int | None
+ min_dynamic_tiles: int | None
+ use_thumbnail: bool | None
+ pad_during_tiling: bool | None
+ do_pad: bool | None
+
+
+@add_start_docstrings(
+ "Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
+ # BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
+ """
+ image_grid_pinpoints (`List[List[int]]`, *optional*):
+ A list of possible resolutions to use for processing high resolution images. The best resolution is selected
+ based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
+ method. Not used for processing videos.
+ do_pad (`bool`, *optional*):
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
+ """,
+)
+class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
+ resample = PILImageResampling.BICUBIC
+ image_mean = IMAGENET_STANDARD_MEAN
+ image_std = IMAGENET_STANDARD_STD
+ size = {"height": 448, "width": 448}
+ default_to_square = False
+ crop_size = None
+ do_resize = True
+ do_center_crop = None
+ do_rescale = True
+ do_normalize = True
+ do_convert_rgb = True
+ do_pad = True
+ max_dynamic_tiles = 12
+ min_dynamic_tiles = 1
+ use_thumbnail = True
+ pad_during_tiling = False
+ valid_kwargs = Eagle25VLFastImageProcessorKwargs
+ model_input_names = ["pixel_values_videos"]
+
+ def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
+ super().__init__(**kwargs)
+
+ @add_start_docstrings(
+        # BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was deprecated in transformers; remove!
+ """
+ max_dynamic_tiles (`int`, *optional*):
+ The maximum number of dynamic tiles to use for processing high resolution images.
+ min_dynamic_tiles (`int`, *optional*):
+ The minimum number of dynamic tiles to use for processing high resolution images.
+ use_thumbnail (`bool`, *optional*):
+ Whether to use a thumbnail for processing high resolution images.
+ pad_during_tiling (`bool`, *optional*):
+ Whether to pad the image during tiling.
+ do_pad (`bool`, *optional*):
+ Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
+ number of patches in the batch. Padding will be applied to the bottom and right with zeros.
+ """,
+ )
+
+ # NOTE(YL): we will overload the preprocess method to add the image_flags
+ # def preprocess(
+ # self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
+ # ) -> BatchFeature:
+ # return super().preprocess(images, **kwargs)
+
+ def _prepare_images_structure(
+ self,
+ images: ImageInput,
+ expected_ndims: int = 3,
+ ) -> ImageInput:
+ """
+ Prepare the images structure for processing.
+
+ Args:
+ images (`ImageInput`):
+ The input images to process.
+ expected_ndims (`int`, *optional*, defaults to 3):
+ Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
+
+ Returns:
+ `ImageInput`: The images with a valid nesting.
+ """
+ return make_flat_list_of_images(images)
+
+ def _resize_for_patching(
+ self,
+ image: "torch.Tensor",
+ target_resolution: tuple,
+ interpolation: "F.InterpolationMode",
+ input_data_format: ChannelDimension,
+ ) -> "torch.Tensor":
+ """
+ Resizes an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image ("torch.Tensor"):
+ The input image.
+ target_resolution (tuple):
+ The target resolution (height, width) of the image.
+ interpolation (`InterpolationMode`):
+ Resampling filter to use if resizing the image.
+ input_data_format (`ChannelDimension` or `str`):
+ The channel dimension format of the input image.
+
+ Returns:
+ "torch.Tensor": The resized and padded image.
+ """
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
+
+ # Resize the image
+ resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+
+ return resized_image
+
+ def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
+ """
+        The previous version mainly focused on the aspect ratio;
+        here we also take the area ratio into account.
+ """
+ best_factor = float("-inf")
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ # ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ # area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
+ """
+ new area > 60% of original image area is enough.
+ """
+ factor_based_on_area_n_ratio = min(
+ (ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
+ ) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
+
+ if factor_based_on_area_n_ratio > best_factor:
+ best_factor = factor_based_on_area_n_ratio
+ best_ratio = ratio
+
+ return best_ratio
+
+ def _pad_for_patching(
+ self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
+ ) -> "torch.Tensor":
+ """
+ Pad an image to a target resolution while maintaining aspect ratio.
+ """
+ target_height, target_width = target_resolution
+ new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
+
+ paste_x = (target_width - new_width) // 2
+ paste_y = (target_height - new_height) // 2
+
+ padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
+
+ return padded_image
+
+ def _get_image_patches(
+ self,
+ image: "torch.Tensor",
+ min_num: int,
+ max_num: int,
+ size: tuple,
+ tile_size: int,
+ use_thumbnail: bool,
+ interpolation: "F.InterpolationMode",
+ pad_during_tiling: bool,
+ ) -> list["torch.Tensor"]:
+ image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
+ orig_height, orig_width = image_size
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = {
+ (i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1)
+ if i * j <= max_num and i * j >= min_num
+ }
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = self.find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, tile_size
+ )
+
+ # calculate the target width and height
+ target_width = tile_size * target_aspect_ratio[0]
+ target_height = tile_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
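+        # target_aspect_ratio is a (num_cols, num_rows) tile grid: the image is resized to
+        # (num_rows * tile_size, num_cols * tile_size) and split row-major into
+        # num_cols * num_rows tiles, optionally followed by a single global thumbnail tile.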
+ if pad_during_tiling:
+ resized_image = self._resize_for_patching(
+ image,
+ (target_height, target_width),
+ interpolation=interpolation,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ padded_image = self._pad_for_patching(
+ resized_image,
+ (target_height, target_width),
+ input_data_format=ChannelDimension.FIRST,
+ )
+ image_used_to_split = padded_image
+ else:
+ image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
+
+ processed_tiles = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // tile_size)) * tile_size,
+ (i // (target_width // tile_size)) * tile_size,
+ ((i % (target_width // tile_size)) + 1) * tile_size,
+ ((i // (target_width // tile_size)) + 1) * tile_size,
+ )
+ # split the image
+ split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
+ processed_tiles.append(split_img)
+ assert len(processed_tiles) == blocks
+
+ if use_thumbnail and len(processed_tiles) != 1:
+ thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
+ processed_tiles.append(thumbnail_img)
+
+ return processed_tiles
+
+ def _pad_for_batching(
+ self,
+ pixel_values: list["torch.Tensor"],
+ ) -> list["torch.Tensor"]:
+ """
+        Pads images on the `num_patches` dimension with zeros so that every image in the batch has the same number of patches.
+
+        Args:
+            pixel_values (`List[torch.Tensor]`):
+                A list of pixel value tensors, one per image, each of shape (`num_patches`, `channels`, `height`, `width`).
+
+ Returns:
+ List[`torch.Tensor`]: The padded images.
+ """
+ max_patch = max(len(x) for x in pixel_values)
+ pixel_values = [
+ torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
+ for image in pixel_values
+ ]
+
+ return pixel_values
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ max_dynamic_tiles: int,
+ min_dynamic_tiles: int,
+ use_thumbnail: bool,
+ pad_during_tiling: bool,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: float | list[float] | None,
+ image_std: float | list[float] | None,
+ do_pad: bool,
+ return_tensors: str | TensorType | None,
+ pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
+ disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
+ ) -> BatchFeature:
+ processed_images = []
+ image_sizes = []
+ # Determine the size tuple
+ if size and size.height and size.width:
+ size_tuple = (size.height, size.width)
+ else:
+ size_tuple = (size.shortest_edge, size.shortest_edge)
+
+ # Determine the patch size
+ if crop_size and crop_size.height:
+ tile_size = crop_size.height
+ elif size and size.height:
+ tile_size = size.height
+ else:
+ tile_size = size.shortest_edge
+
+ for image in images:
+ image_patches = self._get_image_patches(
+ image,
+ min_num=min_dynamic_tiles,
+ max_num=max_dynamic_tiles,
+ size=size_tuple,
+ tile_size=tile_size,
+ use_thumbnail=use_thumbnail,
+ interpolation=interpolation,
+ pad_during_tiling=pad_during_tiling,
+ )
+
+ # Group images by size for batched processing
+ processed_image_patches_grouped = {}
+ # Added for transformers >=4.53.0 compatibility
+ grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
+ image_patches,
+ disable_grouping=disable_grouping,
+ )
+
+ for shape, stacked_image_patches in grouped_image_patches.items():
+ if do_resize:
+ stacked_image_patches = self.resize(
+ image=stacked_image_patches,
+ size=size,
+ interpolation=interpolation,
+ )
+ if do_center_crop:
+ stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
+ # Fused rescale and normalize
+ stacked_image_patches = self.rescale_and_normalize(
+ stacked_image_patches,
+ do_rescale,
+ rescale_factor,
+ do_normalize,
+ image_mean,
+ image_std,
+ )
+ processed_image_patches_grouped[shape] = stacked_image_patches
+ processed_image_patches = reorder_images(
+ processed_image_patches_grouped, grouped_image_patches_index
+ )
+ processed_image_patches = (
+ torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
+ )
+ processed_images.append(processed_image_patches)
+ image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
+
+ if do_pad:
+ processed_images = self._pad_for_batching(processed_images)
+
+ # processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+ processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
+ return BatchFeature(
+ data={"pixel_values": processed_images, "image_sizes": image_sizes},
+ tensor_type=return_tensors,
+ )
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ videos: VideoInput = None,
+ **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
+ ) -> BatchFeature:
+ validate_kwargs(
+ captured_kwargs=kwargs.keys(),
+ valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
+ )
+ # Set default kwargs from self. This ensures that if a kwarg is not provided
+ # by the user, it gets its default value from the instance, or is set to None.
+ for kwarg_name in self.valid_kwargs.__annotations__:
+ kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
+
+ # Extract parameters that are only used for preparing the input images
+ do_convert_rgb = kwargs.pop("do_convert_rgb")
+ input_data_format = kwargs.pop("input_data_format")
+ device = kwargs.pop("device")
+ # Prepare input images
+ # transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
+ if images is not None:
+ images = self._prepare_image_like_inputs(
+ images=images,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ device=device,
+ )
+
+ if videos is not None:
+ videos = self._prepare_image_like_inputs(
+ images=videos,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ device=device,
+ )
+
+ # Update kwargs that need further processing before being validated
+ kwargs = self._further_process_kwargs(**kwargs)
+
+ # Validate kwargs
+ self._validate_preprocess_kwargs(**kwargs)
+
+ # torch resize uses interpolation instead of resample
+ # Added for transformers >=4.53.0 compatibility
+ resample = kwargs.pop("resample", self.resample)
+ kwargs["interpolation"] = (
+ pil_torch_interpolation_mapping[resample]
+ if isinstance(resample, PILImageResampling | int)
+ else resample
+ )
+
+ # Filter kwargs to only include those accepted by _preprocess
+ valid_preprocess_kwargs = {
+ "do_resize",
+ "size",
+ "max_dynamic_tiles",
+ "min_dynamic_tiles",
+ "use_thumbnail",
+ "pad_during_tiling",
+ "interpolation",
+ "do_center_crop",
+ "crop_size",
+ "do_rescale",
+ "rescale_factor",
+ "do_normalize",
+ "image_mean",
+ "image_std",
+ "do_pad",
+ "return_tensors",
+ "pad_size",
+ "disable_grouping",
+ }
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
+ if images is not None:
+ return self._preprocess(images, **filtered_kwargs)
+ elif videos is not None:
+ return self._preprocess(videos, **filtered_kwargs)
+
+
+__all__ = ["Eagle25VLImageProcessorFast"]
diff --git a/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py
new file mode 100755
index 000000000..5a66cfbce
--- /dev/null
+++ b/src/lerobot/policies/groot/eagle2_hg_model/modeling_eagle2_5_vl.py
@@ -0,0 +1,395 @@
+# --------------------------------------------------------
+# NVIDIA
+# Copyright (c) 2025 NVIDIA
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import inspect
+
+import torch
+import torch.utils.checkpoint as cp
+from peft import LoraConfig, get_peft_model
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers import GenerationConfig
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
+from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
+from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+from transformers.utils import add_start_docstrings, logging
+
+from .configuration_eagle2_5_vl import Eagle25VLConfig
+
+logger = logging.get_logger(__name__)
+
+
+# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
+EAGLE2_5_VL_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Eagle25VLConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
+ EAGLE2_5_VL_START_DOCSTRING,
+)
+class Eagle25VLPreTrainedModel(PreTrainedModel):
+ config_class = Eagle25VLConfig
+ base_model_prefix = "model"
+ main_input_name = "input_ids"
+ supports_gradient_checkpointing = True
+ _no_split_modules = [
+ "Qwen2DecoderLayer",
+ "LlamaDecoderLayer",
+ "Siglip2EncoderLayer",
+ "SiglipEncoderLayer",
+ ]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_cache_class = True
+ _supports_static_cache = True
+ _supports_quantized_cache = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear | nn.Conv2d):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
+ config_class = Eagle25VLConfig
+
+ def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
+ super().__init__(config)
+
+ image_size = config.force_image_size or config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+ self.patch_size = patch_size
+ if config.use_pixel_shuffle:
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
+ else:
+ self.num_image_token = int((image_size // patch_size) ** 2)
+
+ self.select_layer = config.select_layer
+ self.downsample_ratio = config.downsample_ratio
+ self.loss_version = config.loss_version
+ self.mlp_checkpoint = config.mlp_checkpoint
+ self.use_pixel_shuffle = config.use_pixel_shuffle
+ self.mlp_connector_layers = config.mlp_connector_layers
+ logger.info(f"num_image_token: {self.num_image_token}")
+ logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
+ if vision_model is not None:
+ self.vision_model = vision_model
+ else:
+ if config.vision_config.model_type == "siglip_vision_model":
+ config.vision_config._attn_implementation = "flash_attention_2"
+ self.vision_model = SiglipVisionModel(config.vision_config)
+ else:
+ raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
+
+ if language_model is not None:
+ self.language_model = language_model
+ else:
+ if config.text_config.architectures[0] == "LlamaForCausalLM":
+ self.language_model = LlamaForCausalLM(config.text_config)
+ elif config.text_config.architectures[0] == "Phi3ForCausalLM":
+ raise NotImplementedError("Phi3 is not implemented.")
+ # self.language_model = Phi3ForCausalLM(config.text_config)
+ elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
+ assert config.text_config._attn_implementation == "flash_attention_2", (
+ f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
+ )
+ self.language_model = Qwen2ForCausalLM(config.text_config)
+ elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
+ self.language_model = Qwen3ForCausalLM(config.text_config)
+ else:
+ raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
+
+ vit_hidden_size = config.vision_config.hidden_size
+ llm_hidden_size = config.text_config.hidden_size
+
+ if config.mlp_connector_layers == 2:
+ self.mlp1 = nn.Sequential(
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
+ nn.GELU(),
+ nn.Linear(llm_hidden_size, llm_hidden_size),
+ )
+ elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
+ self.mlp1 = nn.Sequential(
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
+ )
+ elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
+ self.mlp1 = nn.Sequential(
+ nn.Linear(vit_hidden_size, llm_hidden_size),
+ )
+ else:
+ raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
+
+ self.image_token_index = config.image_token_index
+ self.neftune_alpha = None
+
+ if config.use_backbone_lora:
+ self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
+
+ self.use_llm_lora = config.use_llm_lora
+ if config.use_llm_lora:
+ self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
+
+ self.check_forward_kwargs()
+
+ def check_forward_kwargs(self):
+ # We intentionally avoid using **kwargs in forward because Hugging Face Transformers
+ # has special handling for functions with **kwargs parameters that would affect
+ # how our model is processed during training and inference.
+ forward_params = inspect.signature(self.forward).parameters
+ assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
+
+ def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+ lora_config = LoraConfig(
+ r=r,
+ target_modules=[
+ "self_attn.q_proj",
+ "self_attn.k_proj",
+ "self_attn.v_proj",
+ "self_attn.out_proj",
+ "mlp.fc1",
+ "mlp.fc2",
+ ],
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+ self.vision_model = get_peft_model(self.vision_model, lora_config)
+ self.vision_model.print_trainable_parameters()
+
+ def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+ lora_config = LoraConfig(
+ r=r,
+ target_modules=[
+ "self_attn.q_proj",
+ "self_attn.k_proj",
+ "self_attn.v_proj",
+ "self_attn.o_proj",
+ "mlp.gate_proj",
+ "mlp.down_proj",
+ "mlp.up_proj",
+ ],
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ task_type="CAUSAL_LM",
+ )
+ self.language_model = get_peft_model(self.language_model, lora_config)
+ self.language_model.enable_input_require_grads()
+ self.language_model.print_trainable_parameters()
+ self.use_llm_lora = True
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ image_flags: torch.LongTensor | None = None,
+ past_key_values: list[torch.FloatTensor] | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ num_tiles_list: list[torch.Tensor] | None = None,
+ ) -> tuple | CausalLMOutputWithPast:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+ vit_embeds = self.extract_feature(pixel_values)
+
+ if image_flags is not None:
+ image_flags = image_flags.view(-1)
+ vit_embeds = vit_embeds[image_flags == 1]
+
+ b, n, c = input_embeds.shape
+ input_embeds = input_embeds.reshape(b * n, c)
+
+ input_ids = input_ids.reshape(b * n)
+ selected = input_ids == self.image_token_index
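+        # `selected` marks the image placeholder tokens; their text embeddings are replaced
+        # by the projected vision features (one feature vector per image token).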
+ try:
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
+ except Exception as e:
+ vit_embeds = vit_embeds.reshape(-1, c)
+ print(
+ f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
+ f"vit_embeds.shape={vit_embeds.shape}"
+ )
+ n_token = selected.sum()
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
+
+ input_embeds = input_embeds.reshape(b, n, c)
+
+ outputs = self.language_model(
+ inputs_embeds=input_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ logits = outputs.logits
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
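+    # Pixel unshuffle with scale_factor = 0.5: each 2x2 block of visual tokens is folded into
+    # a single token with 4x the channels, reducing the number of image tokens by a factor of 4.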
+ def pixel_shuffle(self, x, scale_factor=0.5):
+ n, w, h, c = x.size()
+ # N, W, H, C --> N, W, H * scale, C // scale
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+ x = x.permute(0, 2, 1, 3).contiguous()
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
+
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x
+
+ def extract_feature(self, pixel_values):
+ if self.select_layer == -1:
+ vit_embeds = self.vision_model(
+ pixel_values=pixel_values, output_hidden_states=False, return_dict=True
+ )
+ if hasattr(vit_embeds, "last_hidden_state"):
+ vit_embeds = vit_embeds.last_hidden_state
+
+ else:
+ vit_embeds = self.vision_model(
+ pixel_values=pixel_values, output_hidden_states=True, return_dict=True
+ ).hidden_states[self.select_layer]
+
+ if self.use_pixel_shuffle:
+ h = w = int(vit_embeds.shape[1] ** 0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = self.pixel_shuffle(
+ vit_embeds, scale_factor=self.downsample_ratio
+ ) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
+ vit_embeds = vit_embeds.reshape(
+ vit_embeds.shape[0], -1, vit_embeds.shape[-1]
+ ) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
+
+ if self.mlp_checkpoint and vit_embeds.requires_grad:
+ vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
+ else:
+ vit_embeds = self.mlp1(vit_embeds)
+
+ return vit_embeds
+
+ @torch.no_grad()
+ def generate(
+ self,
+ pixel_values: torch.FloatTensor | None = None,
+ input_ids: torch.FloatTensor | None = None,
+ attention_mask: torch.LongTensor | None = None,
+ visual_features: torch.FloatTensor | None = None,
+ generation_config: GenerationConfig | None = None,
+ output_hidden_states: bool | None = None,
+ image_sizes: list[tuple[int, int]] | None = None,
+ **generate_kwargs,
+ ) -> torch.LongTensor:
+ if pixel_values is not None:
+ if visual_features is not None:
+ vit_embeds = visual_features
+ else:
+ vit_embeds = self.extract_feature(pixel_values)
+
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
+ b, n, c = input_embeds.shape
+ input_embeds = input_embeds.reshape(b * n, c)
+
+ input_ids = input_ids.reshape(b * n)
+ selected = input_ids == self.config.image_token_index
+ assert selected.sum() != 0
+ input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
+
+ input_embeds = input_embeds.reshape(b, n, c)
+ else:
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+ if "use_cache" not in generate_kwargs:
+ generate_kwargs["use_cache"] = True
+
+ outputs = self.language_model.generate(
+ inputs_embeds=input_embeds,
+ attention_mask=attention_mask,
+ generation_config=generation_config,
+ output_hidden_states=output_hidden_states,
+ **generate_kwargs,
+ )
+
+ return outputs
+
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
+ def get_output_embeddings(self):
+ return self.language_model.get_output_embeddings()
+
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
+ def set_output_embeddings(self, new_embeddings):
+ self.language_model.set_output_embeddings(new_embeddings)
+
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
+ def set_decoder(self, decoder):
+ self.language_model.set_decoder(decoder)
+
+ # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
+ def get_decoder(self):
+ return self.language_model.get_decoder()
diff --git a/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py
new file mode 100755
index 000000000..27f9b3345
--- /dev/null
+++ b/src/lerobot/policies/groot/eagle2_hg_model/processing_eagle2_5_vl.py
@@ -0,0 +1,518 @@
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Eagle25VL.
+Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
+"""
+
+import base64
+import os
+import re
+from io import BytesIO
+
+import requests
+import torch
+from PIL import Image
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.utils import logging
+from transformers.video_utils import VideoInput
+
+logger = logging.get_logger(__name__)
+
+
+FRAME_FACTOR = 2
+FPS = 2.0
+FPS_MIN_FRAMES = 4
+FPS_MAX_FRAMES = 256
+
+
+def to_rgb(pil_image: Image.Image) -> Image.Image:
+ if pil_image.mode == "RGBA":
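+        # Composite the image onto a white canvas using its alpha channel as the paste mask,
+        # so transparent regions become white rather than black after conversion.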
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
+ return white_background
+ else:
+ return pil_image.convert("RGB")
+
+
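+# fetch_image accepts a dict whose "image" (or "image_url") entry may be a PIL.Image, an
+# http(s) URL, a "file://" URI, a base64 "data:image" URI, or a local path, with an optional
+# integer "scale_factor" for upscaling. A hypothetical call (the path is only an example):
+#   frame = fetch_image({"image": "file:///tmp/frame.png", "scale_factor": 2})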
+def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
+ image = ele["image"] if "image" in ele else ele["image_url"]
+ image_obj = None
+ if isinstance(image, Image.Image):
+ image_obj = image
+ elif image.startswith("http://") or image.startswith("https://"):
+ response = requests.get(image, stream=True, timeout=10)
+ image_obj = Image.open(BytesIO(response.content))
+ elif image.startswith("file://"):
+ image_obj = Image.open(image[7:])
+ elif image.startswith("data:image"):
+ if "base64," in image:
+ _, base64_data = image.split("base64,", 1)
+ data = base64.b64decode(base64_data)
+ image_obj = Image.open(BytesIO(data))
+ else:
+ image_obj = Image.open(image)
+ if image_obj is None:
+ raise ValueError(
+            f"Unrecognized image input; supported inputs are a local path, an http(s) URL, a base64 data URI, or a PIL.Image, got {image}"
+ )
+ image = to_rgb(image_obj)
+ if "scale_factor" in ele:
+ scale_factor = ele["scale_factor"]
+ image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
+ return image
+
+
+class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
+ # see processing_utils.ProcessingKwargs documentation for usage.
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ "images_kwargs": {},
+ "videos_kwargs": {"max_dynamic_tiles": 1},
+ }
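+    # These per-modality defaults are merged with any kwargs passed at call time (call-time
+    # values take precedence), so videos are presumably processed with a single dynamic tile
+    # unless a caller overrides max_dynamic_tiles.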
+
+
+class Eagle25VLProcessor(ProcessorMixin):
+ r"""
+    Constructs an Eagle25VL processor which wraps an Eagle25VL video processor, an Eagle25VL image processor, and an Eagle25VL tokenizer into a single processor.
+
+ [`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
+ [`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ num_image_tokens (`int`, *optional*):
+            Number of image tokens for one image that will be returned by the vision tower.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+            Should be the same as in the model's config.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ image_token (`str`, *optional*, defaults to `""`):
+ Special token used to denote image location.
+ video_token (`str`, *optional*, defaults to `"