mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f9b8f297b4 | |||
| 95527f6051 | |||
| 407ee867b9 | |||
| 26ff40ddd7 | |||
| a5e6409985 | |||
| 6d269b28c8 | |||
| b607c8458e | |||
| 9e83510c99 | |||
| 1c9fbba9a9 | |||
| 6a1b5ceb9d | |||
| daa4c4dd30 | |||
| 1f7b03f5f2 | |||
| ff992a7a1d | |||
| cb8edf17e6 | |||
| 5699f6cbf4 | |||
| 48269dddb3 | |||
| 8df8d3d866 | |||
| 0e6114ac36 | |||
| c8ce413d73 | |||
| 82dffde7fa | |||
| eaf0218bc8 | |||
| a0e52d52fe | |||
| e99c55af4b | |||
| 408e0ca763 | |||
| ce24063efd | |||
| 82934719db | |||
| 401a217597 | |||
| 40094b0464 | |||
| fdbfc015a2 | |||
| d656da8ccc | |||
| b5f65e5332 | |||
| cd6b43ea7a | |||
| 2236bbe7a3 | |||
| cb0a944941 | |||
| 8a3d64033f | |||
| 03ee50e08f |
@@ -382,6 +382,7 @@ jobs:
|
||||
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
||||
--env.type=robotwin \
|
||||
--env.task=\"\$ROBOTWIN_TASKS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -482,6 +483,7 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_robocasa \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -693,6 +695,7 @@ jobs:
|
||||
--env.task=\"\$ROBOMME_TASKS\" \
|
||||
--env.dataset_split=test \
|
||||
--env.task_ids=[0] \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -800,6 +803,7 @@ jobs:
|
||||
--env.type=libero_plus \
|
||||
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
||||
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -900,6 +904,8 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||
--env.episode_length=50 \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
github.event.workflow_run.event == 'pull_request' &&
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
with:
|
||||
package_name: lerobot
|
||||
secrets:
|
||||
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
contents: read
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
|
||||
@@ -152,13 +152,14 @@ jobs:
|
||||
BASE_VERSION="${VERSION%%-*}"
|
||||
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
|
||||
uv pip install \
|
||||
--torch-backend cpu \
|
||||
--index-url https://test.pypi.org/simple/ \
|
||||
--extra-index-url https://pypi.org/simple \
|
||||
--index-strategy unsafe-best-match \
|
||||
"lerobot[all]==$BASE_VERSION"
|
||||
else
|
||||
echo "Installing release version $VERSION from PyPI..."
|
||||
uv pip install "lerobot[all]==$VERSION"
|
||||
uv pip install --torch-backend cpu "lerobot[all]==$VERSION"
|
||||
fi
|
||||
- name: Check lerobot version
|
||||
run: uv run python -c "import lerobot; print(lerobot.__version__)"
|
||||
|
||||
@@ -19,19 +19,19 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
# Runs at 02:00
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
# schedule:
|
||||
# - cron: "0 2 * * *"
|
||||
|
||||
env:
|
||||
CLOSE_ISSUE_MESSAGE: >
|
||||
This issue was closed because it has been stalled for 14 days with no activity.
|
||||
This issue was closed because it has been stalled for 30 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
CLOSE_PR_MESSAGE: >
|
||||
This PR was closed because it has been stalled for 21 days with no activity.
|
||||
This PR was closed because it has been stalled for 30 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this issue will reset this count.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
@@ -59,10 +59,10 @@ jobs:
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: never-stale
|
||||
exempt-pr-labels: never-stale
|
||||
days-before-issue-stale: 180
|
||||
days-before-issue-close: 14
|
||||
days-before-issue-stale: 365
|
||||
days-before-issue-close: 30
|
||||
days-before-pr-stale: 365
|
||||
days-before-pr-close: 21
|
||||
days-before-pr-close: 30
|
||||
delete-branch: true
|
||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||
|
||||
@@ -232,6 +232,8 @@ Match the policy to the user's **GPU memory** and **time budget**. Numbers below
|
||||
|
||||
All policies typically train for **5–10 epochs** (see §7).
|
||||
|
||||
> **Human-facing version:** the [Compute Hardware Guide](./docs/source/hardware_guide.mdx) reuses the table below and adds a cloud-GPU tier guide and a Hugging Face Jobs pointer.
|
||||
|
||||
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
|
||||
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
|
||||
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -109,7 +109,7 @@ lerobot-train \
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). For GPU/RAM requirements and expected training time per policy, see the [Compute Hardware Guide](https://huggingface.co/docs/lerobot/hardware_guide).
|
||||
|
||||
## Inference & Evaluation
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ USER root
|
||||
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
|
||||
cuda-nvcc-12-8 cuda-cudart-dev-12-8 \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||
|
||||
@@ -18,9 +18,8 @@
|
||||
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
||||
|
||||
# Configure the base image for CI with GPU access
|
||||
# TODO(Steven): Bump these versions
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG OS_VERSION=22.04
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version argument
|
||||
@@ -36,16 +35,13 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
# Install Python, system dependencies, and uv (as root)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common build-essential git curl \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
build-essential git curl \
|
||||
libglib2.0-0 libgl1 libegl1 ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
cmake pkg-config ninja-build \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
title: Bring Your Own Hardware
|
||||
- local: hilserl
|
||||
@@ -24,6 +24,12 @@
|
||||
- local: rename_map
|
||||
title: Using Rename Map and Empty Cameras
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: hardware_guide
|
||||
title: Compute Hardware Guide
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Compute & Hardware"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
title: Using LeRobotDataset
|
||||
@@ -47,6 +53,10 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: evo1
|
||||
title: EVO1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
@@ -140,10 +150,6 @@
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
title: "Sensors"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Supported Hardware"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
|
||||
@@ -1,60 +1,37 @@
|
||||
# Bring Your Own Policies
|
||||
# Adding a Policy
|
||||
|
||||
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
|
||||
This guide walks you through implementing a custom policy and getting it to work with LeRobot's training, evaluation, and deployment tools. There are two paths:
|
||||
|
||||
## Step 1: Create a Policy Package
|
||||
- **Plugin (out-of-tree)** — ship your policy as a standalone `lerobot_policy_*` package. Faster, no PR required, easy to iterate. Right for experimentation, internal use, or when you want to publish independently.
|
||||
- **In-tree (contributed to LeRobot)** — land your policy directly in `src/lerobot/policies/`. Requires a PR, but makes your policy a first-class citizen of the library.
|
||||
|
||||
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
|
||||
The plugin route is usually the right starting point — promote to in-tree once the policy has stabilized and there's clear value in shipping it with the library.
|
||||
|
||||
### Package Structure
|
||||
Either way, the building blocks are the same: a configuration class, a policy class, and a processor factory. The first half of this guide covers those shared pieces; the second half covers the path-specific scaffolding ([Path A](#path-a-out-of-tree-plugin), [Path B](#path-b-contributing-in-tree)).
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
A note on tone: robot-learning is an actively evolving field, and "what a policy looks like" can shift with each new architecture. The conventions described here exist because they let `lerobot-train` and `lerobot-eval` work uniformly across very different models. When a new policy genuinely doesn't fit them, raise it (in your PR, or an issue) — the conventions are not sacred.
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_custom_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_custom_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_custom_policy.py
|
||||
├── modeling_my_custom_policy.py
|
||||
└── processor_my_custom_policy.py
|
||||
```
|
||||
---
|
||||
|
||||
### Package Configuration
|
||||
## Anatomy of a policy
|
||||
|
||||
Set up your `pyproject.toml`:
|
||||
Three building blocks make up every policy. The names below use `my_policy` as a placeholder — replace with your policy's name. That name is load-bearing: it must match the string you pass to `@PreTrainedConfig.register_subclass`, the `MyPolicy.name` class attribute, and the `make_<name>_pre_post_processors` factory function (more on each below).
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_custom_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.12"
|
||||
### Configuration class
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
## Step 2: Define the Policy Configuration
|
||||
|
||||
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
|
||||
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||
Inherit from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and register your policy type. Here is a template — customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||
|
||||
```python
|
||||
# configuration_my_custom_policy.py
|
||||
# configuration_my_policy.py
|
||||
from dataclasses import dataclass, field
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig
|
||||
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
||||
@PreTrainedConfig.register_subclass("my_policy")
|
||||
@dataclass
|
||||
class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyCustomPolicy.
|
||||
class MyPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyPolicy.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of observation steps to use as input
|
||||
@@ -77,16 +54,20 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
raise ValueError("n_action_steps cannot exceed horizon")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate input/output feature compatibility."""
|
||||
"""Validate input/output feature compatibility.
|
||||
|
||||
Call this explicitly from your policy's __init__ — the base class does not.
|
||||
"""
|
||||
if not self.image_features:
|
||||
raise ValueError("MyCustomPolicy requires at least one image feature.")
|
||||
raise ValueError("MyPolicy requires at least one image feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
|
||||
raise ValueError("MyPolicy requires 'action' in output_features.")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
"""Return a LRSchedulerConfig from lerobot.optim, or None."""
|
||||
return None
|
||||
|
||||
@property
|
||||
@@ -101,8 +82,7 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Relative timestep offsets for the action chunk the dataset loader returns.
|
||||
"""
|
||||
"""Relative timestep offsets for the action chunk the dataset loader returns."""
|
||||
return list(range(self.horizon))
|
||||
|
||||
@property
|
||||
@@ -110,32 +90,34 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
return None
|
||||
```
|
||||
|
||||
## Step 3: Implement the Policy Class
|
||||
The string you pass to `@register_subclass` must match `MyPolicy.name` (next section) and is what users supply as `--policy.type` on the CLI. Default to `AdamW` from `lerobot.optim` for `get_optimizer_preset` unless you genuinely need otherwise.
|
||||
|
||||
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
|
||||
### Policy class
|
||||
|
||||
Inherit from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py) and set two class attributes — both are checked by `__init_subclass__`:
|
||||
|
||||
```python
|
||||
# modeling_my_custom_policy.py
|
||||
# modeling_my_policy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Any
|
||||
|
||||
from lerobot.policies import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .configuration_my_policy import MyPolicyConfig
|
||||
|
||||
class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
|
||||
name = "my_custom_policy"
|
||||
class MyPolicy(PreTrainedPolicy):
|
||||
config_class = MyPolicyConfig # must match the string in @register_subclass
|
||||
name = "my_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
def __init__(self, config: MyPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features() # not called automatically by the base class
|
||||
self.config = config
|
||||
self.model = ... # your nn.Module here
|
||||
|
||||
def reset(self):
|
||||
"""Reset episode state."""
|
||||
"""Reset per-episode state. Called by lerobot-eval at the start of each episode."""
|
||||
...
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
@@ -147,35 +129,51 @@ class MyCustomPolicy(PreTrainedPolicy):
|
||||
...
|
||||
|
||||
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||
"""Return a single action for the current timestep (called at inference)."""
|
||||
"""Return a single action for the current timestep (called every step at inference)."""
|
||||
...
|
||||
|
||||
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict | None]:
|
||||
"""Compute the training loss.
|
||||
|
||||
Returns `(loss, output_dict)`. `output_dict` may be `None`; everything in it must be
|
||||
logging-friendly Python natives (no tensors with gradients).
|
||||
|
||||
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
||||
timesteps padded because the episode ended before `horizon` steps, you
|
||||
timesteps padded because the episode ended before `horizon` steps; you
|
||||
can exclude those from your loss.
|
||||
"""
|
||||
actions = batch[ACTION]
|
||||
action_is_pad = batch.get("action_is_pad")
|
||||
...
|
||||
return {"loss": ...}
|
||||
return loss, {"some_loss_component": some_loss_component.item()}
|
||||
```
|
||||
|
||||
## Step 4: Add Data Processors
|
||||
The methods called by the train/eval loops:
|
||||
|
||||
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||
| Method | Used by | What it does |
|
||||
| ----------------------------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `reset() -> None` | `lerobot-eval` | Clear per-episode state at the start of each episode. |
|
||||
| `select_action(batch, **kwargs) -> Tensor` | `lerobot-eval` | Return the next action `(B, action_dim)`. Called every step. |
|
||||
| `predict_action_chunk(batch, **kwargs) -> Tensor` | the policy itself | Return an action chunk `(B, chunk_size, action_dim)`. Currently abstract on the base class — raise `NotImplementedError` if your policy doesn't chunk. |
|
||||
| `forward(batch, reduction="mean") -> tuple[Tensor, dict \| None]` | `lerobot-train` | Return `(loss, output_dict)`. Accept `reduction="none"` if you want to support per-sample weighting. |
|
||||
| `get_optim_params() -> dict` | the optimizer | Return `self.parameters()` for simple policies; return a named parameter dict for [multi-optimizer policies](https://github.com/huggingface/lerobot/blob/ecd38c50d7d15b4184cf42649ff1185ee2e11eeb/src/lerobot/policies/sac/modeling_sac.py#L61-L73). |
|
||||
| `update() -> None` _(optional)_ | `lerobot-train` | Called after each optimizer step _if defined_. Use for EMA, target nets, replay buffers (TDMPC uses this). |
|
||||
|
||||
Batches are flat dictionaries keyed by the constants in [`lerobot.utils.constants`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/utils/constants.py): `OBS_STATE` (`observation.state.<motor>`), `OBS_IMAGES` (`observation.images.<camera>`), `OBS_LANGUAGE`, `ACTION`, etc. Reuse the constants — don't invent new prefixes.
|
||||
|
||||
### Processor functions
|
||||
|
||||
LeRobot uses `PolicyProcessorPipeline`s to normalize inputs and de-normalize outputs around your policy. For a concrete reference, see [`processor_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [`processor_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
# processor_my_policy.py
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
|
||||
def make_my_custom_policy_pre_post_processors(
|
||||
def make_my_policy_pre_post_processors(
|
||||
config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
@@ -187,11 +185,48 @@ def make_my_custom_policy_pre_post_processors(
|
||||
return preprocessor, postprocessor
|
||||
```
|
||||
|
||||
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||
**Important — function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||
|
||||
## Step 5: Package Initialization
|
||||
---
|
||||
|
||||
Expose your classes in the package's `__init__.py`:
|
||||
## Path A: Out-of-tree plugin
|
||||
|
||||
The fastest way to ship a policy: package it as a standalone Python distribution and install it alongside LeRobot. No PR required, you own the release cycle, and you can publish to PyPI under your own namespace.
|
||||
|
||||
### Package structure
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_policy.py
|
||||
├── modeling_my_policy.py
|
||||
└── processor_my_policy.py
|
||||
```
|
||||
|
||||
### `pyproject.toml`
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.12"
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
### Package `__init__.py`
|
||||
|
||||
Expose your classes in the package's `__init__.py` and guard against missing `lerobot`:
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
@@ -204,44 +239,148 @@ except ImportError:
|
||||
"lerobot is not installed. Please install lerobot to use this policy package."
|
||||
)
|
||||
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .modeling_my_custom_policy import MyCustomPolicy
|
||||
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
|
||||
from .configuration_my_policy import MyPolicyConfig
|
||||
from .modeling_my_policy import MyPolicy
|
||||
from .processor_my_policy import make_my_policy_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"MyCustomPolicyConfig",
|
||||
"MyCustomPolicy",
|
||||
"make_my_custom_policy_pre_post_processors",
|
||||
"MyPolicyConfig",
|
||||
"MyPolicy",
|
||||
"make_my_policy_pre_post_processors",
|
||||
]
|
||||
```
|
||||
|
||||
## Step 6: Installation and Usage
|
||||
|
||||
### Install Your Policy Package
|
||||
### Install and use
|
||||
|
||||
```bash
|
||||
cd lerobot_policy_my_custom_policy
|
||||
cd lerobot_policy_my_policy
|
||||
pip install -e .
|
||||
|
||||
# Or install from PyPI if published
|
||||
pip install lerobot_policy_my_custom_policy
|
||||
pip install lerobot_policy_my_policy
|
||||
```
|
||||
|
||||
### Use Your Policy
|
||||
|
||||
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type my_custom_policy \
|
||||
--policy.type my_policy \
|
||||
--env.type pusht \
|
||||
--steps 200000
|
||||
```
|
||||
|
||||
## Examples and Community Contributions
|
||||
---
|
||||
|
||||
## Path B: Contributing in-tree
|
||||
|
||||
When your policy has stabilized and there's clear value in shipping it with the library, you can land it directly in LeRobot. Read the general [contribution guide](./contributing) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md) first — that's where you'll find the testing/quality expectations every PR has to meet (`pre-commit run -a`, `pytest`, the community-review rule, etc.). What's below is the policy-specific layer on top of that.
|
||||
|
||||
### In-tree layout
|
||||
|
||||
```
|
||||
src/lerobot/policies/my_policy/
|
||||
├── __init__.py # re-exports config + modeling + processor factory
|
||||
├── configuration_my_policy.py # MyPolicyConfig + @register_subclass
|
||||
├── modeling_my_policy.py # MyPolicy(PreTrainedPolicy)
|
||||
├── processor_my_policy.py # make_my_policy_pre_post_processors
|
||||
└── README.md # symlink → ../../../../docs/source/policy_my_policy_README.md
|
||||
```
|
||||
|
||||
Two notes:
|
||||
|
||||
- The `README.md` next to the source is a **symlink** into `docs/source/policy_<name>_README.md` — the actual file lives under `docs/`. Existing policies (act, smolvla, diffusion, …) all do this; copy one of those symlinks. The policy README is conventionally minimal: paper link + BibTeX citation.
|
||||
- The user-facing tutorial — what to install, how to train, hyperparameters, benchmark numbers — lives separately at `docs/source/<my_policy>.mdx` and is registered in `_toctree.yml` under "Policies".
|
||||
|
||||
The file names are load-bearing: the factory does lazy imports by name, and the processor is discovered by the `make_<policy_name>_pre_post_processors` convention.
|
||||
|
||||
### Wiring
|
||||
|
||||
Three places need to know about your policy. All by name.
|
||||
|
||||
1. **`policies/__init__.py`** — re-export `MyPolicyConfig` and add it to `__all__`. **Don't** re-export the modeling class; it loads lazily through the factory (so `import lerobot` stays fast).
|
||||
2. **`factory.py:get_policy_class`** — add a branch returning `MyPolicy` from a lazy import.
|
||||
3. **`factory.py:make_policy_config`** and **`factory.py:make_pre_post_processors`** — same idea, two more branches.
|
||||
|
||||
Mirror an existing policy that's structurally similar to yours; the diff is small.
|
||||
|
||||
### Heavy / optional dependencies
|
||||
|
||||
Most policies need a heavy backbone (transformers, diffusers, a specific VLM SDK). The convention is **two-step gating**: a `TYPE_CHECKING`-guarded import at module top, and a `require_package` runtime check in the constructor. [`modeling_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/modeling_diffusion.py) is the canonical reference:
|
||||
|
||||
```python
|
||||
from typing import TYPE_CHECKING
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
else:
|
||||
DDIMScheduler = None # keeps the symbol bindable at import time
|
||||
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(self, config):
|
||||
require_package("diffusers", extra="diffusion")
|
||||
super().__init__(config)
|
||||
...
|
||||
```
|
||||
|
||||
This way:
|
||||
|
||||
- `import lerobot.policies` keeps working without the extra installed (the symbol is just bound to `None`).
|
||||
- Type checkers see the real symbol.
|
||||
- Instantiating the policy without the extra raises a clear `ImportError` pointing at `pip install 'lerobot[diffusion]'`.
|
||||
|
||||
Add a matching extra to [`pyproject.toml`](https://github.com/huggingface/lerobot/blob/main/pyproject.toml) `[project.optional-dependencies]` and include it in the `all` extra so `pip install 'lerobot[all]'` keeps installing everything.
|
||||
|
||||
### Benchmarks and a published checkpoint
|
||||
|
||||
A new policy is much easier to review — and far more useful — when it ships with a working checkpoint and at least one number you can reproduce.
|
||||
|
||||
**Pick at least one in-tree benchmark.** LeRobot ships sim benchmarks with per-benchmark Docker images (LIBERO, LIBERO-plus, Meta-World, RoboTwin 2.0, RoboCasa365, RoboCerebra, RoboMME, VLABench and more). Pick the one that matches your policy's modality — VLAs usually go to LIBERO or VLABench; image-only BC to LIBERO or Meta-World. The full list lives under [Benchmarks](./libero) in the docs sidebar.
|
||||
|
||||
**Push the checkpoint & processors** to the Hub under `lerobot/<policy>_<benchmark>` (or your namespace if you don't have write access; a maintainer can mirror it). Use `PreTrainedPolicy.push_model_to_hub` so the repo gets `config.json`, `model.safetensors`, and a model card.
|
||||
|
||||
**Report results in your policy's MDX**, with the exact `lerobot-eval` command and hardware so anyone can re-run:
|
||||
|
||||
```markdown
|
||||
## Results
|
||||
|
||||
Evaluated on LIBERO with `lerobot/<policy>_libero`:
|
||||
|
||||
| Suite | Success rate | n_episodes |
|
||||
| -------------- | -----------: | ---------: |
|
||||
| libero_spatial | 87.5% | 50 |
|
||||
| libero_object | 93.0% | 50 |
|
||||
| libero_goal | 81.5% | 50 |
|
||||
| libero_10 | 62.0% | 50 |
|
||||
| **average** | **81.0%** | 200 |
|
||||
|
||||
Reproduce: `lerobot-eval --policy.path=lerobot/<policy>_libero --env.type=libero --env.task=libero_spatial --eval.n_episodes=50` (1× A100 40 GB).
|
||||
```
|
||||
|
||||
Use `n_episodes ≥ 50` per suite for stable success-rate estimates.
|
||||
|
||||
If your policy is real-robot-only and no sim benchmark applies, swap the sim eval for: a public training dataset on the Hub, the `lerobot-train` command, the checkpoint, and a real-robot success rate over ≥10 episodes via `lerobot-rollout --policy.path=...`.
|
||||
|
||||
### PR checklist
|
||||
|
||||
The general expectations are in [`CONTRIBUTING.md`](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). On top of those, reviewers will look for:
|
||||
|
||||
- [ ] `MyPolicy` and `MyPolicyConfig` cover the surface above; `__init_subclass__` accepts the class.
|
||||
- [ ] `factory.py` and `policies/__init__.py` are wired (lazy imports for modeling).
|
||||
- [ ] `make_my_policy_pre_post_processors` follows the naming convention.
|
||||
- [ ] Optional deps live behind a `[project.optional-dependencies]` extra and the `TYPE_CHECKING + require_package` guard.
|
||||
- [ ] `tests/policies/` updated; backward-compat artifact committed & policy-specific tests.
|
||||
- [ ] `src/lerobot/policies/<name>/README.md` symlinked into `docs/source/policy_<name>_README.md`; user-facing `docs/source/<name>.mdx` written and added to `_toctree.yml`.
|
||||
- [ ] At least one reproducible benchmark eval in the policy MDX with a published checkpoint (sim benchmark, or real-robot dataset + checkpoint).
|
||||
|
||||
The fastest way to get a clean PR is to copy the directory of the existing policy closest to yours, rename, and replace contents method by method. Don't wait until everything is polished — open a draft PR early and iterate with us; reviewers would much rather give feedback on a half-finished branch than a fully-merged one.
|
||||
|
||||
---
|
||||
|
||||
## Examples and community contributions
|
||||
|
||||
Check out these example policy implementations:
|
||||
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) — Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
|
||||
Share your policy implementations with the community! 🤗
|
||||
Thanks for taking the time to bring a new policy into LeRobot. Every architecture that lands in `main` — and every plugin published by the community — makes the library a little more useful for the next person, and a little more representative of where robot learning is going. We're looking forward to seeing what you ship. 🤗
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
# EO-1
|
||||
|
||||
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
|
||||
|
||||
## Model Overview
|
||||
|
||||
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
|
||||
alt="An overview of EO-1"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=eo1` configuration through LeRobot
|
||||
- Qwen2.5-VL image and text preprocessing through policy processors
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with `lerobot-eval`
|
||||
|
||||
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EO-1 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1]"
|
||||
```
|
||||
|
||||
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1,libero]"
|
||||
```
|
||||
|
||||
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EO-1 expects a LeRobot dataset with:
|
||||
|
||||
- At least one visual observation, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction through the dataset `task` field
|
||||
|
||||
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EO-1 in a LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=eo1
|
||||
```
|
||||
|
||||
By default, a new EO-1 policy initializes its backbone from:
|
||||
|
||||
```python
|
||||
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
|
||||
```
|
||||
|
||||
Once a LeRobot-format EO-1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-eo1-checkpoint
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=eo1 \
|
||||
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.attn_implementation=sdpa \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--output_dir=./outputs/eo1_training \
|
||||
--job_name=eo1_training \
|
||||
--steps=300000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
|
||||
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
|
||||
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
|
||||
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
|
||||
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
|
||||
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
|
||||
| `policy.max_state_dim` | `32` | State padding dimension |
|
||||
| `policy.max_action_dim` | `32` | Action padding dimension |
|
||||
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
|
||||
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
|
||||
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
|
||||
|
||||
## Evaluation
|
||||
|
||||
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20
|
||||
```
|
||||
|
||||
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Image Processing
|
||||
|
||||
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
|
||||
|
||||
### State and Action Dimensions
|
||||
|
||||
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
|
||||
|
||||
### Attention Backend
|
||||
|
||||
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
|
||||
|
||||
## References
|
||||
|
||||
- [EO-1 project](https://github.com/EO-Robotics/EO1)
|
||||
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
|
||||
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{eo1,
|
||||
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
|
||||
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
|
||||
journal={arXiv preprint},
|
||||
year={2025},
|
||||
url={https://arxiv.org/abs/2508.21112}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.
|
||||
@@ -0,0 +1,186 @@
|
||||
# EVO1
|
||||
|
||||
EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs.
|
||||
|
||||
## Model Overview
|
||||
|
||||
The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=evo1` configuration through LeRobot
|
||||
- InternVL3 image/text embedding with optional FlashAttention fallback
|
||||
- Stage-based finetuning controls for action-head-only and VLM finetuning runs
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with standard policy inference APIs
|
||||
|
||||
The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EVO1 dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[evo1]"
|
||||
```
|
||||
|
||||
For LIBERO evaluation, install the LIBERO extra as well:
|
||||
|
||||
```bash
|
||||
pip install -e ".[evo1,libero]"
|
||||
```
|
||||
|
||||
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available, but reproducing the official LIBERO checkpoint conversion result below requires the same FlashAttention path used by the original EVO1 checkpoint.
|
||||
|
||||
EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EVO1 expects a LeRobot dataset with:
|
||||
|
||||
- One to `policy.max_views` visual observations, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field`
|
||||
|
||||
State and action vectors are padded to `policy.max_state_dim` and `policy.max_action_dim`. Predictions are cropped back to the dataset action dimension before being returned.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EVO1 in a LeRobot configuration, specify:
|
||||
|
||||
```python
|
||||
policy.type=evo1
|
||||
```
|
||||
|
||||
By default, a new EVO1 policy initializes its VLM from:
|
||||
|
||||
```python
|
||||
policy.vlm_model_name=OpenGVLab/InternVL3-1B
|
||||
```
|
||||
|
||||
Once a LeRobot-format EVO1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-evo1-checkpoint
|
||||
```
|
||||
|
||||
The converted LIBERO checkpoint used for this PR is available at:
|
||||
|
||||
```python
|
||||
policy.path=javadcc/evo1-libero-lerobot
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Stage 1
|
||||
|
||||
Stage 1 freezes the VLM and trains the action head:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=evo1 \
|
||||
--policy.training_stage=stage1 \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
|
||||
--policy.device=cuda \
|
||||
--policy.chunk_size=50 \
|
||||
--policy.n_action_steps=50 \
|
||||
--policy.max_state_dim=24 \
|
||||
--policy.max_action_dim=24 \
|
||||
--policy.optimizer_lr=1e-5 \
|
||||
--batch_size=4 \
|
||||
--steps=5000 \
|
||||
--output_dir=./outputs/evo1_stage1
|
||||
```
|
||||
|
||||
### Stage 2
|
||||
|
||||
Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \
|
||||
--policy.training_stage=stage2 \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
|
||||
--policy.device=cuda \
|
||||
--policy.chunk_size=50 \
|
||||
--policy.n_action_steps=50 \
|
||||
--policy.max_state_dim=24 \
|
||||
--policy.max_action_dim=24 \
|
||||
--policy.optimizer_lr=1e-5 \
|
||||
--batch_size=4 \
|
||||
--steps=80000 \
|
||||
--output_dir=./outputs/evo1_stage2
|
||||
```
|
||||
|
||||
By default, `policy.training_stage` reapplies the finetuning defaults for that stage. This is important when
|
||||
starting Stage 2 from a Stage 1 checkpoint, because the Stage 1 checkpoint config stores the VLM finetuning
|
||||
flags as disabled. These stage defaults take precedence over saved or manually supplied `policy.finetune_*`
|
||||
flags unless `policy.apply_training_stage_defaults=false`, so set that flag only when manually controlling
|
||||
every finetuning flag.
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- |
|
||||
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory |
|
||||
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
|
||||
| `policy.apply_training_stage_defaults` | `true` | Reapplies stage finetuning defaults after loading a checkpoint |
|
||||
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
|
||||
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
|
||||
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
|
||||
| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules |
|
||||
| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported |
|
||||
| `policy.chunk_size` | `50` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.max_state_dim` | `24` | State padding dimension |
|
||||
| `policy.max_action_dim` | `24` | Action padding dimension |
|
||||
| `policy.task_field` | `task` | Batch field used as the language prompt |
|
||||
|
||||
## Results
|
||||
|
||||
### LIBERO Object Checkpoint Conversion
|
||||
|
||||
The checkpoint [javadcc/evo1-libero-lerobot](https://huggingface.co/javadcc/evo1-libero-lerobot)
|
||||
is the LeRobot-format conversion of the official EVO1 LIBERO checkpoint. The conversion was checked against
|
||||
the official EVO1 checkpoint with the same LIBERO Object initial states and action postprocessing.
|
||||
|
||||
| Checkpoint | Suite | Episodes | Success Rate |
|
||||
| ---------------------------- | --------------- | ---------------- | ------------ |
|
||||
| Official EVO1 checkpoint | `libero_object` | 10, one per task | 100% |
|
||||
| LeRobot converted checkpoint | `libero_object` | 10, one per task | 100% |
|
||||
|
||||
For a fixed `libero_object` rollout, the official checkpoint and LeRobot checkpoint produced identical
|
||||
pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions for the checked action step
|
||||
(`max_abs_diff=0.0`).
|
||||
|
||||
The published checkpoint expects the raw LIBERO camera feature names
|
||||
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. To run the converted
|
||||
checkpoint with LeRobot LIBERO evaluation for the same one-episode-per-task setting, keep those camera names
|
||||
instead of the default `image`/`image2` mapping:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=javadcc/evo1-libero-lerobot \
|
||||
--policy.device=cuda \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \
|
||||
--env.observation_height=448 \
|
||||
--env.observation_width=448 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
|
||||
- [InternVL3-1B](https://huggingface.co/OpenGVLab/InternVL3-1B)
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data.
|
||||
@@ -0,0 +1,98 @@
|
||||
# Compute HW Guide for LeRobot Training
|
||||
|
||||
Rough sizing for training a LeRobot policy: how much VRAM each policy needs, what training time looks like, and where to run when local hardware isn't enough.
|
||||
|
||||
The numbers below are **indicative** — order-of-magnitude figures for picking hardware, not exact predictions. Throughput depends heavily on dataset I/O, image resolution, batch size, and number of GPUs.
|
||||
|
||||
## Memory by policy group
|
||||
|
||||
Policies cluster by backbone size; the groupings below give a single VRAM envelope per group instead of repeating numbers per policy. Memory scales roughly linearly with batch size; AdamW (the LeRobot default) carries optimizer state that adds ~30–100% over a forward+backward pass alone.
|
||||
|
||||
| Group | Policies | Peak VRAM (BS 8, AdamW) | Suitable starter GPUs |
|
||||
| ---------- | ------------------------------------------- | ----------------------: | --------------------------------- |
|
||||
| Light BC | `act`, `vqbet`, `tdmpc` | ~2–6GB | Laptop GPU (RTX 3060), L4, A10G |
|
||||
| Diffusion | `diffusion`, `multi_task_dit` | ~8–14GB | RTX 4070+ / L4 / A10G |
|
||||
| Small VLA | `smolvla` | ~10–16GB | RTX 4080+ / L4 / A10G |
|
||||
| Large VLA | `pi0`, `pi0_fast`, `pi05`, `xvla`, `wall_x` | ~24–40GB | A100 40 GB+ (24 GB tight at BS 1) |
|
||||
| Multimodal | `groot`, `eo1` | ~24–40GB | A100 40 GB+ |
|
||||
| RL | `sac` | config-dep. | See [HIL-SERL guide](./hilserl) |
|
||||
|
||||
Memory-bound? Drop the batch size (~linear), use gradient accumulation to recover effective batch, or for SmolVLA leave `freeze_vision_encoder=True`.
|
||||
|
||||
## Training time
|
||||
|
||||
Robotics imitation learning typically converges in **5–10 epochs over the dataset**, not hundreds of thousands of raw steps. Once you know your epoch count, wall-clock is essentially:
|
||||
|
||||
```text
|
||||
total_frames = sum of frames over all episodes # 50 ep × 30 fps × 30 s ≈ 45,000
|
||||
steps_per_epoch = ceil(total_frames / (num_gpus × batch_size))
|
||||
total_steps = epochs × steps_per_epoch
|
||||
wall_clock ≈ total_steps × per_step_time
|
||||
```
|
||||
|
||||
Per-step time depends on the policy and the GPU. The numbers in the table below are anchors — pick the row closest to your setup and scale linearly with `total_steps` if you train longer or shorter.
|
||||
|
||||
### Common scenarios
|
||||
|
||||
Indicative wall-clock for **5 epochs on a ~50-episode dataset (~45k frames at 30 fps × 30 s)**, default optimizer (AdamW), 640×480 images:
|
||||
|
||||
| Setup | Policy | Batch | Wall-clock |
|
||||
| ------------------------------------ | -------------- | ----- | ---------: |
|
||||
| Single RTX 4090 / RTX 3090 (24 GB) | `act` | 8 | ~30–60min |
|
||||
| Single RTX 4090 / RTX 3090 (24 GB) | `diffusion` | 8 | ~2–4h |
|
||||
| Single L4 / A10G (24 GB) | `act` | 8 | ~1–2h |
|
||||
| Single L4 / A10G (24 GB) | `smolvla` | 4 | ~3–6h |
|
||||
| Single A100 40 GB | `smolvla` | 16 | ~1–2h |
|
||||
| Single A100 40 GB | `pi0` / `pi05` | 4 | ~4–8h |
|
||||
| 4× H100 80 GB cluster (`accelerate`) | `diffusion` | 32 | ~30–60min |
|
||||
| 4× H100 80 GB cluster (`accelerate`) | `smolvla` | 32 | ~1–2h |
|
||||
| Apple Silicon M1/M2/M3 Max (MPS) | `act` | 4 | ~6–14h |
|
||||
|
||||
These are order-of-magnitude figures. Real runs deviate by ±50% depending on image resolution, dataset I/O, dataloader threading, and exact GPU SKU. They are useful as "is this run going to take an hour or a day?" intuition, not as SLAs.
|
||||
|
||||
### Multi-GPU matters a lot
|
||||
|
||||
`accelerate launch --num_processes=N` is the easiest way to cut training time. Each optimizer step processes `N × batch_size` samples in roughly the same wall-clock as a single-GPU step, so 4 GPUs ≈ 4× speedup for compute-bound runs. See the [Multi GPU training](./multi_gpu_training) guide for the full setup.
|
||||
|
||||
Reference data points on a 4×H100 80 GB cluster (`accelerate launch --num_processes=4`), 5000 steps, batch 32, AdamW, dataset [`imstevenpmwork/super_poulain_draft`](https://huggingface.co/datasets/imstevenpmwork/super_poulain_draft) (~50 episodes, ~640×480 images):
|
||||
|
||||
| Policy | Wall-clock | `update_s` | `dataloading_s` | GPU util | Notable flags |
|
||||
| ----------- | ---------- | ---------: | --------------: | -------- | ------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `diffusion` | 16m 17s | 0.167 | 0.015 | ~90% | defaults (training from scratch) |
|
||||
| `smolvla` | 27m 49s | 0.312 | 0.011 | ~80% | `--policy.path=lerobot/smolvla_base`, `freeze_vision_encoder=false`, `train_expert_only=false` |
|
||||
| `pi05` | 3h 41m | 2.548 | 0.014 | ~95% | `--policy.pretrained_path=lerobot/pi05_base`, `gradient_checkpointing=true`, `dtype=bfloat16`, vision encoder + expert trained |
|
||||
|
||||
The `dataloading_s` vs. `update_s` ratio is the diagnostic that matters: when `dataloading_s` approaches `update_s`, more GPUs stop helping — your dataloader is the bottleneck and you should look at `--num_workers`, image resolution, and disk speed before adding compute.
|
||||
|
||||
### Schedule and checkpoints
|
||||
|
||||
If you shorten training (e.g. 5k–10k steps on a small dataset), also shorten the LR schedule with `--policy.scheduler_decay_steps≈--steps`. Otherwise the LR stays near its peak and never decays. Same for `--save_freq`.
|
||||
|
||||
## Where to run
|
||||
|
||||
VRAM is the first filter. Within a tier, pick by budget and availability — the `$`–`$$$$` columns are relative; check current pricing on the provider you actually use.
|
||||
|
||||
| Class | VRAM | Tier | Comfortable for |
|
||||
| -------------------------- | ----- | ------ | ----------------------------------------------------------- |
|
||||
| RTX 3090 / 4090 (consumer) | 24 GB | `$` | Light BC, Diffusion, SmolVLA. Tight for VLAs at batch 1. |
|
||||
| L4 / A10G (cloud) | 24 GB | `$–$$` | Same envelope; common on Google Cloud, RunPod, AWS `g5/g6`. |
|
||||
| A100 40 GB | 40 GB | `$$$` | Any policy at reasonable batch sizes. |
|
||||
| A100 80 GB / H100 80 GB | 80 GB | `$$$$` | Multi-GPU clusters; large batches for VLAs. |
|
||||
| **CPU only** | — | — | Don't train. Use Colab or rent a GPU. |
|
||||
|
||||
### Hugging Face Jobs
|
||||
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) lets you run training on managed HF infrastructure, billed by the second. The repo publishes a ready-to-use image: **`huggingface/lerobot-gpu:latest`**, rebuilt **every night at 02:00 UTC from `main`** ([`docker_publish.yml`](https://github.com/huggingface/lerobot/blob/main/.github/workflows/docker_publish.yml)) — so it tracks the current state of the repo, not a tagged release.
|
||||
|
||||
```bash
|
||||
hf jobs run --flavor a10g-large huggingface/lerobot-gpu:latest \
|
||||
bash -c "nvidia-smi && lerobot-train \
|
||||
--policy.type=act --dataset.repo_id=<USER>/<DATASET> \
|
||||
--policy.repo_id=<USER>/act_<task> --batch_size=8 --steps=50000"
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
|
||||
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
|
||||
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
|
||||
@@ -820,10 +820,10 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
|
||||
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
|
||||
|
||||
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
|
||||
1. Configure the policy settings (`type="sac"`, `device`, etc.)
|
||||
2. Set `dataset` to your cropped dataset
|
||||
3. Configure environment settings with crop parameters
|
||||
4. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
|
||||
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
|
||||
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
|
||||
**Starting the Learner**
|
||||
@@ -926,7 +926,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
|
||||
|
||||
Some configuration values have a disproportionate impact on training stability and speed:
|
||||
|
||||
- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
||||
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
||||
|
||||
|
||||
@@ -207,6 +207,56 @@ pip install 'lerobot[feetech]' # Feetech motor support
|
||||
|
||||
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
|
||||
|
||||
### PyTorch CUDA variant (Linux only)
|
||||
|
||||
On Linux, the install path determines which CUDA wheel you get. macOS and Windows installs use the PyPI default (MPS / CPU / CUDA-Windows wheel respectively) and can skip this section.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
|
||||
<hfoptions id="cuda_variant">
|
||||
<hfoption id="uv-source">
|
||||
|
||||
**Source install via `uv` (`uv sync` or `uv pip install -e .`)**
|
||||
|
||||
`torch` and `torchvision` are pinned by the project to the **CUDA 12.8** PyTorch index (`https://download.pytorch.org/whl/cu128`, driver floor **570.86**) — covers Ampere/Ada/Hopper/Blackwell GPUs. No action needed for typical NVIDIA setups.
|
||||
|
||||
To override for a different CUDA variant:
|
||||
|
||||
```bash
|
||||
uv pip install --force-reinstall torch torchvision \
|
||||
--index-url https://download.pytorch.org/whl/cu126 # older drivers; or cu130 for Blackwell on driver ≥ 580
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="pip-conda">
|
||||
|
||||
**Source install via `pip`/`conda`, or `pip install lerobot` from PyPI**
|
||||
|
||||
PyPI default torch wheel is currently a cu130-bundled Linux wheel, driver floor **580.65**.
|
||||
|
||||
To pick a specific CUDA variant:
|
||||
|
||||
**Using `pip` or `conda`** — install torch first with an explicit index, then lerobot:
|
||||
|
||||
```bash
|
||||
pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
|
||||
pip install -e ".[all]" # source
|
||||
# — or —
|
||||
pip install lerobot # from PyPI
|
||||
```
|
||||
|
||||
**Using `uv` to install from PyPI** — one-liner via `--torch-backend` (uv ≥ 0.6):
|
||||
|
||||
```bash
|
||||
uv pip install --torch-backend cu128 lerobot
|
||||
```
|
||||
|
||||
Supported values include `auto`, `cpu`, `cu126`, `cu128`, `cu129`, `cu130`, plus various `rocm*` and `xpu`. Swap as needed for your driver.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
# EVO1
|
||||
|
||||
EVO1 is a Vision-Language-Action policy for robot control. The LeRobot
|
||||
integration uses an InternVL3 vision-language backbone with a flow-matching
|
||||
action head, and supports staged training through the standard LeRobot policy
|
||||
APIs.
|
||||
|
||||
The upstream EVO1 project is available at
|
||||
[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1).
|
||||
|
||||
```bibtex
|
||||
@misc{evo1,
|
||||
title = {EVO1},
|
||||
author = {{MINT-SJTU}},
|
||||
year = {2026},
|
||||
howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}},
|
||||
}
|
||||
```
|
||||
+29
-28
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
|
||||
|
||||
## Inputs and Targets (What the new code expects)
|
||||
|
||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
||||
SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
|
||||
|
||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
|
||||
<hfoption id="single_stage">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
<hfoption id="dense_only">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
<hfoption id="dual">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
|
||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
@@ -465,15 +465,15 @@ This script:
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_head_mode=sparse \
|
||||
--rabc_kappa=0.01 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.head_mode=sparse \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -488,12 +488,13 @@ The training script automatically:
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| Argument | Description | Default |
|
||||
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
|
||||
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
|
||||
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------ | ------------- | ------------------------- |
|
||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ----------------------------- | ------------- | ------------------------- |
|
||||
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
**Setting kappa based on your data:**
|
||||
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
|
||||
|
||||
```
|
||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||
# Most deltas fall in range [0.01, 0.05]
|
||||
|
||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||
--rabc_kappa=0.03
|
||||
--sample_weighting.kappa=0.03
|
||||
|
||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||
--rabc_kappa=0.05
|
||||
--sample_weighting.kappa=0.05
|
||||
|
||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||
--rabc_kappa=0.07
|
||||
--sample_weighting.kappa=0.07
|
||||
```
|
||||
|
||||
**When RA-BC may not help:**
|
||||
@@ -550,8 +551,8 @@ accelerate launch \
|
||||
src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -576,7 +577,7 @@ accelerate launch \
|
||||
### RA-BC
|
||||
|
||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
||||
from lerobot.rewards.sarm.compute_rabc_weights import (
|
||||
generate_all_frame_indices,
|
||||
interpolate_progress,
|
||||
load_sarm_resources,
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
# OMX Follower — Cube Pick And Place Example
|
||||
|
||||
This is an example of what is possible to do with LeRobot on a physical setup.
|
||||
It is a WIP and being used internally at LeRobot and specific to our setup, but we hope it can be a useful reference for how to use LeRobot APIs and CLIs.
|
||||
|
||||
It includes an end-to-end example for the **OMX Follower** robot arm: pick and place a cube dataset, train a policy, and deploy it autonomously.
|
||||
|
||||
## Hardware
|
||||
|
||||
| Component | Value |
|
||||
| --------- | ------------------------------------ |
|
||||
| Robot | OMX Follower |
|
||||
| Cameras | 2× OpenCV cameras (wrist + top-down) |
|
||||
|
||||
## Scripts
|
||||
|
||||
| Script | Purpose |
|
||||
| ---------------------- | --------------------------------------------------------------- |
|
||||
| `reset_environment.py` | Standalone utility: sweep workspace, grab cube, place cube |
|
||||
| `record_grab.py` | Automated data collection: reset → place → record grab episodes |
|
||||
|
||||
## Setup
|
||||
|
||||
Make sure you have LeRobot installed in your env. (See [the installation guide](https://huggingface.co/docs/lerobot/installation))
|
||||
|
||||
Next, we will declare some environment variables for convenience. Adjust the camera indices and robot port to match your system configuration.
|
||||
|
||||
```bash
|
||||
export ROBOT_PORT=/dev/ttyACM0
|
||||
export TELEOP_PORT=/dev/ttyACM1
|
||||
export HF_USERNAME=<your_hf_username>
|
||||
export ROBOT_CAMERAS="{ wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: MJPG} }"
|
||||
```
|
||||
|
||||
## Step 1 — Collect Data
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=$TELEOP_PORT \
|
||||
--teleop.id=omx_leader \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--dataset.root=data/omx_pickandplace \
|
||||
--dataset.num_episodes=50 \
|
||||
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
### Bonus Auto-Collect script
|
||||
|
||||
/!\ This is specific to our setup and the task of picking and placing a cube. It is not a general-purpose data collection script. As you may notice, it doesn't require a teleop.
|
||||
|
||||
```bash
|
||||
python -m examples.omx.record_grab \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--dataset.root=data/omx_pickandplace \
|
||||
--dataset.num_episodes=50 \
|
||||
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Each episode:
|
||||
|
||||
1. The arm grabs the cube from the center of the workspace and places it at a random position.
|
||||
2. The arm returns to HOME.
|
||||
3. A targeted grab is recorded: HOME → approach raised → lower onto cube → grasp → lift → carry → drop → HOME.
|
||||
|
||||
A dataset is already available here [`maximellerbach/omx_pickandplace`](https://huggingface.co/datasets/maximellerbach/omx_pickandplace), so you can skip directly to training if you want.
|
||||
|
||||
## Step 2 — Train
|
||||
|
||||
To train a simple `ACT` policy on the collected dataset, you can use the `lerobot-train` CLI:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/omx_pickandplace_act \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=$HF_USERNAME/omx_pickandplace_act \
|
||||
--steps=20000 \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
A pretrained `ACT` policy is already available here [`maximellerbach/omx_pickandplace_act`](https://huggingface.co/maximellerbach/omx_pickandplace_act).
|
||||
|
||||
## Step 3 — Rollout
|
||||
|
||||
Use the `lerobot-rollout` CLI with base strategy:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||
```
|
||||
|
||||
For continuous recording with automatic upload (sentry mode):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=10 \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||
--dataset.repo_id=$HF_USERNAME/rollout_omx_pickandplace_act \
|
||||
```
|
||||
|
||||
## Environment Reset Utility
|
||||
|
||||
Those are specific to this particular physical setup. Those are scripts that execute hardcoded sequences of actions on the robot to reset the environment, which is useful for data collection and evaluation. They are not general-purpose scripts.
|
||||
|
||||
`reset_environment.py` can be run standalone to prepare the workspace:
|
||||
|
||||
```bash
|
||||
# Grab cube + place it at a random position on the left side
|
||||
python -m examples.omx.reset_environment --port $ROBOT_PORT --mode grab_and_place
|
||||
```
|
||||
|
||||
It also exposes `grab_cube(robot)` and `place_cube(robot)` for use in custom scripts.
|
||||
@@ -0,0 +1,422 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-record grab episodes for the OMX robot arm.
|
||||
|
||||
Each episode cycle:
|
||||
1. grab_and_place — grab cube from workspace center and place at a random (pan, reach) position
|
||||
2. HOME — return arm to home with gripper open
|
||||
3. record_grab — execute a targeted grab to the stored position while recording
|
||||
observations + actions to a LeRobotDataset
|
||||
|
||||
Usage (run from repo root):
|
||||
python -m examples.omx.record_grab \\
|
||||
--robot.type=omx_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--robot.id=omx_follower \\
|
||||
--robot.cameras="{ wrist: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 4, width: 640, height: 480, fps: 30, fourcc: MJPG} }" \\
|
||||
--dataset.repo_id=<hf_username>/<dataset_name> \\
|
||||
--dataset.root=data/omx_grab \\
|
||||
--dataset.num_episodes=50 \\
|
||||
--dataset.single_task="Grab the cube" \\
|
||||
--dataset.streaming_encoding=true
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras import CameraConfig # noqa: F401
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
VideoEncodingManager,
|
||||
aggregate_pipeline_dataset_features,
|
||||
create_initial_features,
|
||||
)
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots import RobotConfig, make_robot_from_config
|
||||
from lerobot.robots.omx_follower import OmxFollower
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from .reset_environment import (
|
||||
APPROACH_SPEED,
|
||||
GRIPPER_CLOSE_POS,
|
||||
HOME_POSE,
|
||||
PUSH_END_ELBOW_FLEX,
|
||||
PUSH_END_SHOULDER_LIFT,
|
||||
PUSH_START_ELBOW_FLEX,
|
||||
PUSH_START_SHOULDER_LIFT,
|
||||
array_to_pose,
|
||||
grab_cube,
|
||||
horizontal_wrist_flex,
|
||||
move_to_pose,
|
||||
place_cube,
|
||||
pose_to_array,
|
||||
)
|
||||
|
||||
# ── Grab-episode motion parameters ────────────────────────────────────────────
|
||||
|
||||
# Shoulder-lift offset for the raised approach phase (subtracted from the target sl, arm is higher).
|
||||
GRAB_RAISE_SL_OFFSET = 20.0
|
||||
GRAB_LOWER_SPEED = 20.0
|
||||
RECORD_SPEED = 30.0
|
||||
|
||||
# Pose the arm travels to after closing the gripper (cube held).
|
||||
GRAB_CARRY_POSE = {
|
||||
"shoulder_pan.pos": -23.0,
|
||||
"shoulder_lift.pos": 5.0,
|
||||
"elbow_flex.pos": 18.0,
|
||||
"wrist_flex.pos": -14.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
}
|
||||
|
||||
# Per-joint jitter limits (degrees) applied to transit waypoints for human-like variation.
|
||||
# Cube-approach and carry poses are never jittered to preserve precision.
|
||||
_JITTER_LIMITS: dict[str, float] = {
|
||||
"shoulder_pan.pos": 5.0,
|
||||
"shoulder_lift.pos": 4.0,
|
||||
"elbow_flex.pos": 4.0,
|
||||
"wrist_flex.pos": 3.0,
|
||||
"wrist_roll.pos": 2.0,
|
||||
"gripper.pos": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def _jitter_pose(pose: dict, rng: np.random.Generator) -> dict:
|
||||
"""Return a copy of pose with independent per-joint random perturbations."""
|
||||
return {
|
||||
k: v + rng.uniform(-_JITTER_LIMITS.get(k, 0.0), _JITTER_LIMITS.get(k, 0.0)) for k, v in pose.items()
|
||||
}
|
||||
|
||||
|
||||
def _random_stuck_pose(rng: np.random.Generator) -> dict:
|
||||
"""Return a physically plausible stuck pose (failed grasp), gripper closed.
|
||||
|
||||
ef bounds are piecewise-linear in sl so the arm stays in a reachable,
|
||||
table-safe envelope across the full sl range:
|
||||
sl=-50 → ef ∈ [ 0, 50] (arm raised, can be bent forward)
|
||||
sl= 0 → ef ∈ [-25, 25] (mid reach)
|
||||
sl= 30 → ef ∈ [-20, 0] (arm extended, little room to flex)
|
||||
wrist_flex is randomly offset from the horizontal value.
|
||||
"""
|
||||
pan = float(rng.uniform(-5.0, 35.0))
|
||||
sl = float(rng.uniform(-50.0, 30.0))
|
||||
|
||||
if sl <= 0.0:
|
||||
alpha = (sl + 50.0) / 50.0 # 0 at sl=-50, 1 at sl=0
|
||||
ef_lo = alpha * -25.0 # 0 → -25
|
||||
ef_hi = 50.0 + alpha * -25.0 # 50 → 25
|
||||
else:
|
||||
alpha = sl / 30.0 # 0 at sl=0, 1 at sl=30
|
||||
ef_lo = -25.0 + alpha * 5.0 # -25 → -20
|
||||
ef_hi = 25.0 + alpha * -25.0 # 25 → 0
|
||||
|
||||
ef = float(rng.uniform(ef_lo, ef_hi))
|
||||
wf = horizontal_wrist_flex(sl, ef) + float(rng.uniform(-15.0, 15.0))
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf,
|
||||
"wrist_roll.pos": float(rng.uniform(-15.0, 15.0)),
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OmxRecordGrabConfig:
|
||||
robot: RobotConfig
|
||||
dataset: DatasetRecordConfig
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Fraction of episodes that start from a random stuck pose (gripper closed) to
|
||||
# generate recovery data. 0.0 = disabled, 1.0 = all episodes are recovery starts.
|
||||
recovery_prob: float = 0.5
|
||||
|
||||
|
||||
def record_episode_spline(
|
||||
robot: OmxFollower,
|
||||
waypoints: list[dict],
|
||||
speeds: list[float],
|
||||
dataset: LeRobotDataset,
|
||||
task: str,
|
||||
) -> None:
|
||||
"""Execute a Catmull-Rom-style spline through waypoints, recording each frame.
|
||||
|
||||
Segment durations are parameterized from the maximum absolute joint delta
|
||||
between consecutive waypoints divided by the requested segment speed,
|
||||
producing non-uniform timing in joint space. Interior tangents are derived
|
||||
from the adjacent per-segment velocities, with clamped (zero-velocity)
|
||||
endpoints so the arm starts and stops smoothly. Each segment is cubic
|
||||
Hermite, giving C1 continuity at every waypoint.
|
||||
"""
|
||||
pts = [pose_to_array(w) for w in waypoints]
|
||||
n = len(pts)
|
||||
|
||||
# Steps and duration per segment
|
||||
n_steps_list = []
|
||||
timestamps = []
|
||||
for i in range(n - 1):
|
||||
max_dist = float(np.max(np.abs(pts[i + 1] - pts[i])))
|
||||
ns = max(1, int(max_dist / speeds[i] * dataset.fps)) if max_dist >= 0.5 else 0
|
||||
n_steps_list.append(ns)
|
||||
timestamps.append(ns / dataset.fps)
|
||||
|
||||
# Velocity tangents (deg/sec) — clamped at endpoints, Catmull-Rom for interior
|
||||
vels = [np.zeros_like(pts[0])]
|
||||
for i in range(1, n - 1):
|
||||
v_prev = (pts[i] - pts[i - 1]) / timestamps[i - 1] if timestamps[i - 1] > 0 else np.zeros_like(pts[0])
|
||||
v_next = (pts[i + 1] - pts[i]) / timestamps[i] if timestamps[i] > 0 else np.zeros_like(pts[0])
|
||||
vels.append(0.5 * (v_prev + v_next))
|
||||
vels.append(np.zeros_like(pts[0]))
|
||||
|
||||
dt = 1.0 / dataset.fps
|
||||
for seg in range(n - 1):
|
||||
ns = n_steps_list[seg]
|
||||
if ns == 0:
|
||||
continue
|
||||
p0, p1 = pts[seg], pts[seg + 1]
|
||||
# Scale velocity (deg/sec) to t-space tangent (deg/t-unit, where t: 0→1 over ns steps)
|
||||
m0 = vels[seg] * timestamps[seg]
|
||||
m1 = vels[seg + 1] * timestamps[seg]
|
||||
|
||||
for step in range(1, ns + 1):
|
||||
t = step / ns
|
||||
h00 = 2 * t**3 - 3 * t**2 + 1
|
||||
h10 = t**3 - 2 * t**2 + t
|
||||
h01 = -2 * t**3 + 3 * t**2
|
||||
h11 = t**3 - t**2
|
||||
commanded = h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
|
||||
|
||||
action = array_to_pose(commanded)
|
||||
robot.send_action(action)
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(dataset.features, action, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": task})
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
def record_grab_episode(
|
||||
robot: OmxFollower,
|
||||
dataset: LeRobotDataset,
|
||||
pan: float,
|
||||
t: float,
|
||||
task: str,
|
||||
recovery_start: bool = False,
|
||||
) -> None:
|
||||
"""Execute a targeted grab to the stored (pan, t) position, recording every frame.
|
||||
|
||||
Normal sequence (initial HOME move is NOT recorded):
|
||||
HOME → raised approach above cube → lower → close gripper
|
||||
→ raise [jittered] → retract [jittered] → GRAB_CARRY_POSE → drop → HOME
|
||||
|
||||
Recovery sequence (recovery_start=True): arm is moved to a random stuck pose
|
||||
(gripper closed) without recording, then recording begins from there:
|
||||
stuck_pose → raised approach above cube → [normal grab sequence from there]
|
||||
|
||||
All segments are joined by a Catmull-Rom spline (C1-continuous velocities).
|
||||
"""
|
||||
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||
sl_raised = sl - GRAB_RAISE_SL_OFFSET
|
||||
wf_horizontal = horizontal_wrist_flex(sl, ef)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
if recovery_start:
|
||||
stuck_pose = _random_stuck_pose(rng)
|
||||
logger.info(f"Recovery start: {stuck_pose}")
|
||||
move_to_pose(robot, stuck_pose, APPROACH_SPEED)
|
||||
first_waypoints = [stuck_pose]
|
||||
first_speeds = []
|
||||
else:
|
||||
jittery_start = _jitter_pose(HOME_POSE, rng)
|
||||
move_to_pose(robot, jittery_start, APPROACH_SPEED)
|
||||
first_waypoints = [jittery_start]
|
||||
first_speeds = []
|
||||
|
||||
waypoints = first_waypoints + [
|
||||
{ # raised approach: arm above cube
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl_raised,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{ # lower onto cube — no jitter: precision needed
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf_horizontal,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{ # close gripper — no jitter: precision needed
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf_horizontal,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
_jitter_pose(
|
||||
{ # raise with cube
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl_raised,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
rng,
|
||||
),
|
||||
_jitter_pose(
|
||||
{ # retract: fold arm toward HOME before sweeping to carry zone
|
||||
"shoulder_pan.pos": pan * 0.25,
|
||||
"shoulder_lift.pos": HOME_POSE["shoulder_lift.pos"] + 5.0,
|
||||
"elbow_flex.pos": HOME_POSE["elbow_flex.pos"] - 5.0,
|
||||
"wrist_flex.pos": 0.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
rng,
|
||||
),
|
||||
GRAB_CARRY_POSE, # no jitter: target drop zone
|
||||
{**GRAB_CARRY_POSE, "gripper.pos": 60.0}, # drop cube
|
||||
HOME_POSE,
|
||||
]
|
||||
speeds = first_speeds + [
|
||||
RECORD_SPEED, # (HOME →) raised approach
|
||||
GRAB_LOWER_SPEED, # raised approach → lower
|
||||
GRAB_LOWER_SPEED, # lower → close gripper
|
||||
RECORD_SPEED, # close gripper → raise
|
||||
RECORD_SPEED, # raise → retract
|
||||
RECORD_SPEED, # retract → carry pose
|
||||
RECORD_SPEED, # carry pose → drop
|
||||
RECORD_SPEED, # drop → HOME
|
||||
]
|
||||
|
||||
record_episode_spline(robot, waypoints, speeds, dataset, task)
|
||||
|
||||
# Dwell at HOME for ~0.5 s before next episode
|
||||
home_action = build_dataset_frame(dataset.features, HOME_POSE, prefix=ACTION)
|
||||
dt = 1.0 / dataset.fps
|
||||
for _ in range(int(dataset.fps * 0.5)):
|
||||
robot.send_action(HOME_POSE)
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
dataset.add_frame({**obs_frame, **home_action, "task": task})
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def record_grab(cfg: OmxRecordGrabConfig) -> LeRobotDataset:
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
logger.info(pformat(cfg))
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
use_videos = cfg.dataset.video
|
||||
|
||||
teleop_action_processor, _, robot_obs_processor = make_default_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=use_videos,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_obs_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=use_videos,
|
||||
),
|
||||
)
|
||||
|
||||
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
|
||||
dataset = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset.resume(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
else 0,
|
||||
)
|
||||
else:
|
||||
cfg.dataset.stamp_repo_id()
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=use_videos,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
else 0,
|
||||
)
|
||||
|
||||
robot.connect(calibrate=True)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
with VideoEncodingManager(dataset):
|
||||
for episode_idx in range(cfg.dataset.num_episodes):
|
||||
logger.info(f"=== Episode {episode_idx + 1}/{cfg.dataset.num_episodes} ===")
|
||||
|
||||
logger.info("Step 1: grabbing and placing cube...")
|
||||
grab_cube(robot)
|
||||
pan, t = place_cube(robot)
|
||||
logger.info(f"Cube placed at pan={pan:.1f}, reach={t:.2f}")
|
||||
|
||||
recovery_start = cfg.recovery_prob > 0 and float(rng.random()) < cfg.recovery_prob
|
||||
logger.info(f"Step 2: recording {'recovery ' if recovery_start else ''}grab episode...")
|
||||
record_grab_episode(
|
||||
robot,
|
||||
dataset,
|
||||
pan,
|
||||
t,
|
||||
cfg.dataset.single_task,
|
||||
recovery_start=recovery_start,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
logger.info(f"Episode {episode_idx + 1} saved.")
|
||||
|
||||
finally:
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
if cfg.dataset.push_to_hub and dataset and dataset.num_episodes > 0:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
record_grab()
|
||||
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-reset and cube-grab utility for the OMX robot arm.
|
||||
|
||||
Provides:
|
||||
- grab_cube(robot): sweep workspace, center cube, close gripper
|
||||
- place_cube(robot): carry cube to a random position, release
|
||||
|
||||
Standalone usage (run from repo root):
|
||||
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab
|
||||
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab_and_place
|
||||
|
||||
Joint range: -100 to 100 for arm joints; gripper: 50 = closed, 80 = open.
|
||||
|
||||
To read current joint values for calibration, add after robot.connect():
|
||||
obs = robot.get_observation()
|
||||
print({k: round(obs[k], 1) for k in JOINT_NAMES})
|
||||
robot.disconnect(); raise SystemExit
|
||||
|
||||
Parallel-to-ground IK: wrist_flex = WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex.
|
||||
Linear interpolation preserves this constraint between any two poses that satisfy it.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.omx_follower import OmxFollower, OmxFollowerConfig
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Poses ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
HOME_POSE = {
|
||||
"shoulder_pan.pos": 0.0,
|
||||
"shoulder_lift.pos": -50.0,
|
||||
"elbow_flex.pos": 50.0,
|
||||
"wrist_flex.pos": 0.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
}
|
||||
|
||||
SWEEP_WAYPOINTS = [
|
||||
{
|
||||
"shoulder_pan.pos": -60.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -60.0,
|
||||
"wrist_flex.pos": -20.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{
|
||||
"shoulder_pan.pos": -30.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -60.0,
|
||||
"wrist_flex.pos": -5.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{
|
||||
"shoulder_pan.pos": 20.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -55.0,
|
||||
"wrist_flex.pos": -5.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
]
|
||||
|
||||
# ── Motion parameters ─────────────────────────────────────────────────────────
|
||||
|
||||
CONTROL_HZ = 30
|
||||
APPROACH_SPEED = 50.0
|
||||
SWEEP_SPEED = 40.0
|
||||
|
||||
# ── Grab-sequence parameters ──────────────────────────────────────────────────
|
||||
|
||||
GRAB_PAN = 0.0
|
||||
SWEEP_LEFT_PAN = -60.0
|
||||
SWEEP_RIGHT_PAN = 60.0
|
||||
SWEEP_END_OFFSET = 5.0 # stop before center so the cube isn't pushed past GRAB_PAN
|
||||
SWEEP_END_PAN_RANGE = (15.0, 20.0)
|
||||
|
||||
SWEEP_LOW_SHOULDER_LIFT = 50.0
|
||||
SWEEP_LOW_ELBOW_FLEX_START = -60.0
|
||||
SWEEP_LOW_ELBOW_FLEX_END = -55.0
|
||||
|
||||
SWEEP_HIGH_WRIST_FLEX = -20.0 # wrist tilted up during high approach to clear obstacles
|
||||
|
||||
PUSH_START_SHOULDER_LIFT = 0.0
|
||||
PUSH_START_ELBOW_FLEX = 45.0
|
||||
PUSH_END_SHOULDER_LIFT = 50.0
|
||||
PUSH_END_ELBOW_FLEX = -50.0
|
||||
# Subtracted from shoulder_lift during the push sweep to clear the platform surface.
|
||||
# Does not affect the grab-target interpolation in record_grab.py.
|
||||
PUSH_RAISE_OFFSET = 5.0
|
||||
|
||||
WRIST_HORIZONTAL_OFFSET = 0.0 # tune if gripper tilts during push: + tilts nose up, - down
|
||||
GRIPPER_CLOSE_POS = 50.0
|
||||
|
||||
PLACE_LEFT_PAN_RANGE = (5.0, 30.0) # random pan range for cube placement on the left side
|
||||
PLACE_REACH_RANGE = (0.1, 0.7) # 0 = arm retracted (PUSH_START), 1 = fully extended (PUSH_END)
|
||||
|
||||
JOINT_NAMES = [
|
||||
"shoulder_pan.pos",
|
||||
"shoulder_lift.pos",
|
||||
"elbow_flex.pos",
|
||||
"wrist_flex.pos",
|
||||
"wrist_roll.pos",
|
||||
"gripper.pos",
|
||||
]
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def pose_to_array(pose: dict) -> np.ndarray:
|
||||
return np.array([pose[k] for k in JOINT_NAMES])
|
||||
|
||||
|
||||
def array_to_pose(arr: np.ndarray) -> dict:
|
||||
return {k: float(arr[i]) for i, k in enumerate(JOINT_NAMES)}
|
||||
|
||||
|
||||
def horizontal_wrist_flex(shoulder_lift: float, elbow_flex: float) -> float:
|
||||
return WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex
|
||||
|
||||
|
||||
def _low_sweep_pose(pan: float, elbow_flex: float, wrist_flex: float | None = None) -> dict:
|
||||
sl = SWEEP_LOW_SHOULDER_LIFT
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": elbow_flex,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl, elbow_flex) if wrist_flex is None else wrist_flex,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
}
|
||||
|
||||
|
||||
def _high_sweep_pose(pan: float) -> dict:
|
||||
return {**HOME_POSE, "shoulder_pan.pos": pan, "wrist_flex.pos": SWEEP_HIGH_WRIST_FLEX}
|
||||
|
||||
|
||||
def _push_pose(shoulder_lift: float, elbow_flex: float, pan: float = GRAB_PAN, gripper: float = 70.0) -> dict:
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": shoulder_lift,
|
||||
"elbow_flex.pos": elbow_flex,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(shoulder_lift, elbow_flex),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": gripper,
|
||||
}
|
||||
|
||||
|
||||
def move_to_pose(robot: Robot, target: dict, speed: float) -> None:
|
||||
"""Interpolate from current position to target at the given speed (units/s)."""
|
||||
obs = robot.get_observation()
|
||||
current = np.array([obs[k] for k in JOINT_NAMES])
|
||||
goal = pose_to_array(target)
|
||||
|
||||
max_distance = float(np.max(np.abs(goal - current)))
|
||||
if max_distance < 0.5:
|
||||
return
|
||||
|
||||
n_steps = max(1, int(max_distance / speed * CONTROL_HZ))
|
||||
dt = 1.0 / CONTROL_HZ
|
||||
for step in range(1, n_steps + 1):
|
||||
t = step / n_steps
|
||||
robot.send_action(array_to_pose(current + t * (goal - current)))
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
# ── Sequences ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def grab_cube(robot: Robot) -> None:
|
||||
"""Left sweep → right sweep → extend arm parallel to ground → close gripper."""
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
|
||||
for pan, end_pan in [
|
||||
(SWEEP_LEFT_PAN, GRAB_PAN - SWEEP_END_OFFSET),
|
||||
(SWEEP_RIGHT_PAN, GRAB_PAN + SWEEP_END_OFFSET),
|
||||
]:
|
||||
logger.info(f"Sweeping {'left' if pan < 0 else 'right'} → center...")
|
||||
move_to_pose(robot, _high_sweep_pose(pan), APPROACH_SPEED)
|
||||
move_to_pose(
|
||||
robot, _low_sweep_pose(pan, SWEEP_LOW_ELBOW_FLEX_START, wrist_flex=-20.0), APPROACH_SPEED
|
||||
)
|
||||
move_to_pose(robot, _low_sweep_pose(end_pan, SWEEP_LOW_ELBOW_FLEX_END, wrist_flex=0.0), SWEEP_SPEED)
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
|
||||
logger.info("Extending to push cube into gripper...")
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_START_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_START_ELBOW_FLEX),
|
||||
APPROACH_SPEED,
|
||||
)
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_END_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_END_ELBOW_FLEX),
|
||||
SWEEP_SPEED,
|
||||
)
|
||||
|
||||
logger.info("Closing gripper...")
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_END_SHOULDER_LIFT, PUSH_END_ELBOW_FLEX, gripper=GRIPPER_CLOSE_POS),
|
||||
APPROACH_SPEED,
|
||||
)
|
||||
|
||||
logger.info("Grab complete.")
|
||||
|
||||
|
||||
def place_cube(robot: Robot) -> tuple[float, float]:
|
||||
"""Carry the cube (gripper closed) to a random position on the left side, then release.
|
||||
|
||||
Returns:
|
||||
(pan, t): pan angle and reach scalar [0, 1] of the placement position.
|
||||
"""
|
||||
pan = float(np.random.uniform(*PLACE_LEFT_PAN_RANGE))
|
||||
t = float(np.random.uniform(*PLACE_REACH_RANGE))
|
||||
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||
logger.info(f"Placing cube at pan={pan:.1f}, reach={t:.2f}...")
|
||||
|
||||
move_to_pose(robot, {**HOME_POSE, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED)
|
||||
move_to_pose(
|
||||
robot, {**HOME_POSE, "shoulder_pan.pos": pan, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED
|
||||
)
|
||||
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=GRIPPER_CLOSE_POS), APPROACH_SPEED)
|
||||
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=80.0), APPROACH_SPEED)
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
logger.info("Place complete.")
|
||||
return pan, t
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="OMX arm reset / grab script")
|
||||
parser.add_argument("--port", default="/dev/ttyACM1")
|
||||
parser.add_argument("--robot_id", default="omx_follower")
|
||||
parser.add_argument("--mode", choices=["grab", "grab_and_place"], default="grab_and_place")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
robot = OmxFollower(OmxFollowerConfig(port=args.port, id=args.robot_id))
|
||||
robot.connect(calibrate=True)
|
||||
|
||||
try:
|
||||
if args.mode == "grab":
|
||||
grab_cube(robot)
|
||||
elif args.mode == "grab_and_place":
|
||||
grab_cube(robot)
|
||||
place_cube(robot)
|
||||
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,175 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Simple SO100/SO101 leader-follower teleoperation with spacebar intervention toggle.
|
||||
|
||||
Modes:
|
||||
- Default (not intervening): follower holds its current position.
|
||||
The leader arm has torque ENABLED and mirrors the follower so there is no
|
||||
large position jump when intervention starts.
|
||||
- Intervention (SPACE pressed): leader torque DISABLED, human moves the leader
|
||||
freely, and the follower mirrors the leader joint-by-joint.
|
||||
|
||||
Usage:
|
||||
uv run python examples/so100_teleop/teleop.py
|
||||
|
||||
Controls:
|
||||
SPACE — toggle intervention on/off
|
||||
Ctrl+C — exit
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from threading import Event, Thread
|
||||
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO101Leader
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SOLeaderTeleopConfig
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── pynput keyboard listener ─────────────────────────────────────────────────
|
||||
PYNPUT_AVAILABLE = True
|
||||
try:
|
||||
if "DISPLAY" not in os.environ and "linux" in sys.platform:
|
||||
raise ImportError("No DISPLAY set, pynput skipped.")
|
||||
from pynput import keyboard as pynput_keyboard
|
||||
except Exception:
|
||||
pynput_keyboard = None
|
||||
PYNPUT_AVAILABLE = False
|
||||
|
||||
# ── Configure ports ──────────────────────────────────────────────────────────
|
||||
FOLLOWER_PORT = "/dev/ttyUSB0" # ← change to your follower port
|
||||
LEADER_PORT = "/dev/ttyUSB1" # ← change to your leader port
|
||||
FPS = 30
|
||||
|
||||
|
||||
def hold_position(robot) -> dict:
|
||||
"""Read current joint positions and write them back as the goal.
|
||||
|
||||
This prevents the motors from snapping to a stale Goal_Position register
|
||||
value (which can happen when torque is re-enabled after calibration).
|
||||
Returns the current position dict for reuse.
|
||||
"""
|
||||
current = robot.bus.sync_read("Present_Position")
|
||||
robot.bus.sync_write("Goal_Position", current)
|
||||
return {f"{motor}.pos": val for motor, val in current.items()}
|
||||
|
||||
|
||||
# ── Connect ───────────────────────────────────────────────────────────────────
|
||||
follower_config = SO101FollowerConfig(
|
||||
port=FOLLOWER_PORT,
|
||||
id="follower_arm",
|
||||
use_degrees=True,
|
||||
)
|
||||
leader_config = SOLeaderTeleopConfig(
|
||||
port=LEADER_PORT,
|
||||
id="leader_arm",
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
follower = SO101Follower(follower_config)
|
||||
leader = SO101Leader(leader_config)
|
||||
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
# ── CRITICAL: hold both arms at their current position before doing anything ─
|
||||
# configure() enables follower torque, and the Goal_Position register may contain
|
||||
# a stale value from a previous session. Writing current→goal prevents sudden motion.
|
||||
follower_current = hold_position(follower)
|
||||
leader_current = hold_position(leader) # leader torque is still off here, but sets the register
|
||||
|
||||
# ── Intervention state + keyboard listener ───────────────────────────────────
|
||||
is_intervening = False
|
||||
stop_event = Event()
|
||||
|
||||
|
||||
def _start_keyboard_listener():
|
||||
if not PYNPUT_AVAILABLE:
|
||||
logger.warning("pynput not available — spacebar toggle disabled.")
|
||||
return None
|
||||
|
||||
def on_press(key):
|
||||
global is_intervening
|
||||
if key == pynput_keyboard.Key.space:
|
||||
is_intervening = not is_intervening
|
||||
state = "INTERVENTION (leader → follower)" if is_intervening else "IDLE (follower holds)"
|
||||
print(f"\n[SPACE] {state}\n")
|
||||
|
||||
def listen():
|
||||
with pynput_keyboard.Listener(on_press=on_press) as listener:
|
||||
while not stop_event.is_set():
|
||||
time.sleep(0.05)
|
||||
listener.stop()
|
||||
|
||||
t = Thread(target=listen, daemon=True)
|
||||
t.start()
|
||||
return t
|
||||
|
||||
|
||||
kbd_thread = _start_keyboard_listener()
|
||||
|
||||
# Enable leader torque AFTER writing its goal to current position, so it holds in place.
|
||||
leader.bus.sync_write("Torque_Enable", 1)
|
||||
leader_torque_on = True
|
||||
|
||||
print("\nTeleoperation ready.")
|
||||
print(" SPACE → toggle intervention (leader controls follower)")
|
||||
print(" Ctrl+C → exit\n")
|
||||
|
||||
try:
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if is_intervening:
|
||||
# ── Intervention: leader torque OFF, follower mirrors leader ──────
|
||||
if leader_torque_on:
|
||||
leader.bus.sync_write("Torque_Enable", 0)
|
||||
leader_torque_on = False
|
||||
|
||||
leader_action = leader.get_action() # reads present leader joints
|
||||
follower.send_action(leader_action) # follower tracks leader
|
||||
|
||||
else:
|
||||
# ── Idle: leader torque ON, leader mirrors follower, follower holds
|
||||
if not leader_torque_on:
|
||||
# Before re-enabling torque, set the leader's goal to its current
|
||||
# position so it doesn't snap to the follower position suddenly.
|
||||
hold_position(leader)
|
||||
leader.bus.sync_write("Torque_Enable", 1)
|
||||
leader_torque_on = True
|
||||
|
||||
follower_obs = follower.get_observation()
|
||||
# Command leader to match follower (so next intervention has no jump)
|
||||
goal_pos = {motor: follower_obs[f"{motor}.pos"] for motor in leader.bus.motors}
|
||||
leader.bus.sync_write("Goal_Position", goal_pos)
|
||||
# Follower holds — no send_action call
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
finally:
|
||||
stop_event.set()
|
||||
leader.bus.sync_write("Torque_Enable", 0)
|
||||
follower.disconnect()
|
||||
leader.disconnect()
|
||||
@@ -1,365 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
ProcessorStepRegistry,
|
||||
RobotAction,
|
||||
RobotActionProcessorStep,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
create_transition,
|
||||
identity_transition,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsRLStep,
|
||||
)
|
||||
from lerobot.robots.so101_follower.config_so101_follower import SO101FollowerConfig
|
||||
from lerobot.robots.so101_follower.so101_follower import SO101Follower
|
||||
from lerobot.teleoperators.so101_leader.config_so101_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so101_leader.so101_leader import SO101Leader
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
|
||||
def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None:
|
||||
"""Reset robot arm to target position using smooth trajectory."""
|
||||
current_position_dict = robot_arm.bus.sync_read("Present_Position")
|
||||
current_position = np.array(
|
||||
[current_position_dict[name] for name in current_position_dict],
|
||||
dtype=np.float32,
|
||||
)
|
||||
trajectory = torch.from_numpy(
|
||||
np.linspace(current_position, target_position, 50)
|
||||
) # NOTE: 30 is just an arbitrary number
|
||||
for pose in trajectory:
|
||||
action_dict = dict(zip(current_position_dict, pose, strict=False))
|
||||
robot_arm.bus.sync_write("Goal_Position", action_dict)
|
||||
precise_sleep(0.015)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogRobotAction(RobotActionProcessorStep):
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
print(f"Robot action: {action}")
|
||||
return action
|
||||
|
||||
def transform_features(self, features):
|
||||
# features[PipelineFeatureType.ACTION][ACTION] = PolicyFeature(
|
||||
# type=FeatureType.ACTION, shape=(len(self.motor_names),)
|
||||
# )
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee_target_action")
|
||||
@dataclass
|
||||
class ForwardKinematicsJointsToEETargetAction(RobotActionProcessorStep):
|
||||
"""
|
||||
Computes the end-effector pose from joint positions using forward kinematics (FK).
|
||||
|
||||
This step is typically used to add the robot's Cartesian pose to the observation space,
|
||||
which can be useful for visualization or as an input to a policy.
|
||||
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
end_effector_step_sizes: dict
|
||||
max_gripper_pos: float
|
||||
use_ik_solution: bool = False
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
# return compute_forward_kinematics_joints_to_ee(action, self.kinematics, self.motor_names)
|
||||
teleop_action = action
|
||||
raw_joint_pos = self.transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
leader_pos = np.array([teleop_action[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
leader_ee = self.kinematics.forward_kinematics(leader_pos)
|
||||
|
||||
if self.use_ik_solution and "IK_solution" in self.transition.get(TransitionKey.COMPLEMENTARY_DATA):
|
||||
follower_pos = transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"]
|
||||
else:
|
||||
follower_pos = np.array([raw_joint_pos[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
follower_ee = self.kinematics.forward_kinematics(follower_pos)
|
||||
|
||||
follower_ee_pos = follower_ee[:3, 3]
|
||||
follower_ee_rvec = Rotation.from_matrix(follower_ee[:3, :3]).as_rotvec()
|
||||
# follower_gripper_pos = raw_joint_pos["gripper.pos"]
|
||||
follower_gripper_pos = follower_pos[-1] # assuming gripper is the last motor
|
||||
|
||||
leader_ee_pos = leader_ee[:3, 3]
|
||||
leader_ee_rvec = Rotation.from_matrix(leader_ee[:3, :3]).as_rotvec()
|
||||
leader_gripper_pos = np.clip(
|
||||
teleop_action["gripper.pos"], -self.max_gripper_pos, self.max_gripper_pos
|
||||
)
|
||||
|
||||
print("f pos:", follower_ee_pos)
|
||||
print("l pos:", leader_ee_pos)
|
||||
|
||||
print("f rvec:", follower_ee_rvec)
|
||||
print("l rvec:", leader_ee_rvec)
|
||||
|
||||
# follower_ee_pos = follower_ee[:3, 3]
|
||||
# follower_ee_rvec = Rotation.from_matrix(follower_ee[:3, :3]).as_rotvec()
|
||||
|
||||
delta_pos = leader_ee_pos - follower_ee_pos
|
||||
|
||||
# For rotation: compute relative rotation from follower to leader
|
||||
# R_leader = R_follower * R_delta => R_delta = R_follower^T * R_leader
|
||||
r_delta = follower_ee[:3, :3].T @ leader_ee[:3, :3]
|
||||
delta_rvec = Rotation.from_matrix(r_delta).as_rotvec()
|
||||
delta_gripper = leader_gripper_pos - follower_gripper_pos
|
||||
|
||||
desired = np.eye(4, dtype=float)
|
||||
desired[:3, :3] = follower_ee[:3, :3] @ r_delta
|
||||
desired[:3, 3] = follower_ee[:3, 3] + delta_pos
|
||||
|
||||
pos = desired[:3, 3]
|
||||
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
||||
|
||||
assert np.allclose(pos, leader_ee_pos), "Position delta computation error"
|
||||
assert np.allclose(tw, leader_ee_rvec), "Orientation delta computation error"
|
||||
assert np.isclose(follower_gripper_pos + delta_gripper, leader_gripper_pos), (
|
||||
"Gripper delta computation error"
|
||||
)
|
||||
|
||||
# Normalize the action to the range [-1, 1]
|
||||
delta_pos = delta_pos / np.array(
|
||||
[
|
||||
self.end_effector_step_sizes["x"],
|
||||
self.end_effector_step_sizes["y"],
|
||||
self.end_effector_step_sizes["z"],
|
||||
]
|
||||
)
|
||||
delta_rvec = delta_rvec / np.array(
|
||||
[
|
||||
self.end_effector_step_sizes["wx"],
|
||||
self.end_effector_step_sizes["wy"],
|
||||
self.end_effector_step_sizes["wz"],
|
||||
]
|
||||
)
|
||||
|
||||
# Check if any of the normalized deltas exceed 1.0
|
||||
|
||||
max_normalized_pos = max(
|
||||
abs(delta_pos[0]),
|
||||
abs(delta_pos[1]),
|
||||
abs(delta_pos[2]),
|
||||
)
|
||||
|
||||
max_normalized_rot = max(
|
||||
abs(delta_rvec[0]),
|
||||
abs(delta_rvec[1]),
|
||||
abs(delta_rvec[2]),
|
||||
)
|
||||
|
||||
# Use the same scaling factor for both position and rotation
|
||||
max_normalized = max(max_normalized_pos, max_normalized_rot)
|
||||
if max_normalized > 1.0:
|
||||
print(f"Warning: EE delta too large, scaling. Max normalized delta: {max_normalized_pos}")
|
||||
print(f"Original delta_pos: {delta_pos}, delta_rvec: {delta_rvec}")
|
||||
# Scale proportionally
|
||||
delta_pos = delta_pos / max_normalized
|
||||
delta_rvec = delta_rvec / max_normalized
|
||||
|
||||
new_action = {}
|
||||
new_action["enabled"] = True
|
||||
new_action["target_x"] = float(delta_pos[0])
|
||||
new_action["target_y"] = float(delta_pos[1])
|
||||
new_action["target_z"] = float(delta_pos[2])
|
||||
new_action["target_wx"] = float(delta_rvec[0])
|
||||
new_action["target_wy"] = float(delta_rvec[1])
|
||||
new_action["target_wz"] = float(delta_rvec[2])
|
||||
new_action["gripper_vel"] = float(
|
||||
np.clip(delta_gripper, -self.max_gripper_pos, self.max_gripper_pos) / self.max_gripper_pos
|
||||
)
|
||||
return new_action
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# TODO: implement feature transformation
|
||||
return features
|
||||
|
||||
|
||||
FPS = 20
|
||||
|
||||
# Initialize the robot and teleoperator config
|
||||
follower_config = SO101FollowerConfig(port="/dev/usb_follower_arm_a", id="follower_arm_a", use_degrees=True)
|
||||
leader_config = SO101LeaderConfig(port="/dev/usb_leader_arm_a", id="leader_arm_a", use_degrees=True)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO101Follower(follower_config)
|
||||
leader = SO101Leader(leader_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="../SO-ARM100/Simulation/SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="../SO-ARM100/Simulation/SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
end_effector_step_sizes = {
|
||||
"x": 0.004,
|
||||
"y": 0.004,
|
||||
"z": 0.004,
|
||||
"wx": 5 * np.pi / 180,
|
||||
"wy": 5 * np.pi / 180,
|
||||
"wz": 5 * np.pi / 180,
|
||||
}
|
||||
|
||||
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
LogRobotAction(),
|
||||
ForwardKinematicsJointsToEETargetAction(
|
||||
kinematics=leader_kinematics_solver,
|
||||
motor_names=list(leader.bus.motors.keys()),
|
||||
end_effector_step_sizes=end_effector_step_sizes,
|
||||
max_gripper_pos=30.0,
|
||||
use_ik_solution=True,
|
||||
),
|
||||
LogRobotAction(),
|
||||
],
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
# build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
LogRobotAction(),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=follower_kinematics_solver,
|
||||
# end_effector_step_sizes={"x": 0.006, "y": 0.01, "z": 0.005},
|
||||
end_effector_step_sizes=end_effector_step_sizes,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
use_latched_reference=False,
|
||||
use_ik_solution=True,
|
||||
),
|
||||
LogRobotAction(),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={
|
||||
"min": [-0.05, -0.55, -0.0075],
|
||||
"max": [0.55, 0.55, 0.55],
|
||||
},
|
||||
# end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.05,
|
||||
),
|
||||
LogRobotAction(),
|
||||
GripperVelocityToJoint(
|
||||
clip_max=30.0,
|
||||
speed_factor=0.2,
|
||||
discrete_gripper=False,
|
||||
scale_velocity=True,
|
||||
use_ik_solution=True,
|
||||
),
|
||||
LogRobotAction(),
|
||||
InverseKinematicsRLStep(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
LogRobotAction(),
|
||||
],
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
reset_pose = [0.0, 10, 20, 60.00, 90.00, 10.00]
|
||||
|
||||
start_time = time.perf_counter()
|
||||
reset_follower_position(follower, np.array(reset_pose))
|
||||
reset_follower_position(leader, np.array(reset_pose))
|
||||
precise_sleep(5.0 - (time.perf_counter() - start_time))
|
||||
# time.sleep(10)
|
||||
leader.bus.sync_write("Torque_Enable", 0)
|
||||
|
||||
# Init rerun viewer
|
||||
# init_rerun(session_name="so100_so100_EE_teleop")
|
||||
|
||||
transition = None
|
||||
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
print("New loop iteration")
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
|
||||
# teleop joints -> teleop EE action
|
||||
if transition is None:
|
||||
transition = create_transition(action=leader_joints_obs, observation=robot_obs)
|
||||
else:
|
||||
transition = create_transition(
|
||||
action=leader_joints_obs,
|
||||
observation=robot_obs,
|
||||
complementary_data=transition.get(TransitionKey.COMPLEMENTARY_DATA),
|
||||
)
|
||||
|
||||
transition = leader_to_ee(transition)
|
||||
leader_ee_act = transition[TransitionKey.ACTION]
|
||||
|
||||
# teleop EE -> robot joints
|
||||
transition = create_transition(
|
||||
action=leader_ee_act,
|
||||
observation=robot_obs,
|
||||
complementary_data=transition.get(TransitionKey.COMPLEMENTARY_DATA),
|
||||
)
|
||||
transition = ee_to_follower_joints(transition)
|
||||
follower_joints_act = transition[TransitionKey.ACTION]
|
||||
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
# Visualize
|
||||
# log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
@@ -4,13 +4,13 @@ from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.policies import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
@@ -28,7 +28,7 @@ def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: GaussianActorPolicy,
|
||||
policy_learner: SACPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
@@ -40,9 +40,8 @@ def run_learner(
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
@@ -84,26 +83,24 @@ def run_learner(
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
def batch_iter(b=batch):
|
||||
while True:
|
||||
yield b
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
|
||||
stats = algorithm.update(batch_iter())
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
log_dict = stats.to_log_dict()
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, "
|
||||
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
weights = algorithm.get_weights()
|
||||
parameters_queue.put_nowait(weights)
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
@@ -116,7 +113,7 @@ def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: GaussianActorPolicy,
|
||||
policy_actor: SACPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
@@ -147,15 +144,15 @@ def run_actor(
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_weights = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_weights)
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy (returns full action: continuous + discrete)
|
||||
# Get action from policy
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
@@ -264,14 +261,14 @@ def main():
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = GaussianActorConfig(
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = GaussianActorPolicy(policy_cfg)
|
||||
policy_learner = GaussianActorPolicy(policy_cfg)
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
|
||||
|
||||
|
||||
def main():
|
||||
@@ -22,10 +22,10 @@ def main():
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
# Make reward model, preprocessor, and optimizer
|
||||
reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
|
||||
optimizer = config.get_optimizer_preset().build(reward_model.parameters())
|
||||
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
|
||||
@@ -42,7 +42,7 @@ def main():
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
loss, output_dict = reward_model.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -58,8 +58,8 @@ def main():
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
# You can now save the trained reward model.
|
||||
reward_model.push_to_hub(classifier_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+22
-4
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
|
||||
dependencies = [
|
||||
# Core ML
|
||||
"torch>=2.7,<2.11.0",
|
||||
"torchvision>=0.22.0,<0.26.0",
|
||||
"torch>=2.7,<2.12.0",
|
||||
"torchvision>=0.22.0,<0.27.0",
|
||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||
"Pillow>=10.0.0,<13.0.0",
|
||||
@@ -99,7 +99,7 @@ dataset = [
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
|
||||
"torchcodec>=0.3.0,<0.12.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10), 0.11 needs torch==2.11, 0.12 needs torch==2.12.
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
]
|
||||
training = [
|
||||
@@ -128,7 +128,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -194,6 +194,8 @@ groot = [
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
evo1 = ["lerobot[transformers-dep]", "timm>=1.0.0,<1.1.0"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -257,6 +259,7 @@ all = [
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[evo1]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
@@ -292,6 +295,20 @@ lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
# cu128 wheels keep broad hardware reach; the driver floor is 570.86.
|
||||
# To use a different CUDA variant, reinstall torch with an explicit index, e.g.:
|
||||
# uv pip install --force-reinstall torch torchvision \
|
||||
# --index-url https://download.pytorch.org/whl/cu130
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
@@ -333,6 +350,7 @@ ignore = [
|
||||
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
||||
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||
"src/lerobot/policies/evo1/**" = ["N801", "N812"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
|
||||
@@ -99,7 +99,6 @@ def save_checkpoint(
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
|
||||
@@ -41,8 +41,12 @@ def cfg_to_group(
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
trainable_tag = f"reward_model:{cfg.reward_model.type}"
|
||||
else:
|
||||
trainable_tag = f"policy:{cfg.policy.type}"
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
trainable_tag,
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
T = TypeVar("T", bound="RewardModelConfig")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""Base configuration for reward models.
|
||||
|
||||
Args:
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
"""
|
||||
|
||||
# Reuses PolicyFeature
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None
|
||||
|
||||
pretrained_path: str | None = None
|
||||
|
||||
push_to_hub: bool = False
|
||||
repo_id: str | None = None
|
||||
|
||||
# Hub metadata
|
||||
license: str | None = None
|
||||
tags: list[str] | None = None
|
||||
private: bool | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
pass
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**reward_kwargs: Any,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
if Path(model_id).is_dir():
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
# HACK: Parse the original config to get the config subclass, so that we can
|
||||
# apply cli overrides.
|
||||
with draccus.config_type("json"):
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
config.pop("type", None)
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(config, f)
|
||||
config_file = f.name
|
||||
|
||||
cli_overrides = reward_kwargs.pop("cli_overrides", [])
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
|
||||
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
import builtins
|
||||
import datetime as dt
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -26,18 +28,57 @@ from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||
legacy_fields = (
|
||||
"use_rabc",
|
||||
"rabc_progress_path",
|
||||
"rabc_kappa",
|
||||
"rabc_epsilon",
|
||||
"rabc_head_mode",
|
||||
)
|
||||
if not any(key in config for key in legacy_fields):
|
||||
return None
|
||||
|
||||
migrated_config = dict(config)
|
||||
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||
|
||||
# New configs may already define sample_weighting explicitly. In that case,
|
||||
# legacy fields are ignored after being stripped from the payload.
|
||||
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||
if rabc_progress_path is not None:
|
||||
sample_weighting["progress_path"] = rabc_progress_path
|
||||
if rabc_kappa is not None:
|
||||
sample_weighting["kappa"] = rabc_kappa
|
||||
if rabc_epsilon is not None:
|
||||
sample_weighting["epsilon"] = rabc_epsilon
|
||||
if rabc_head_mode is not None:
|
||||
sample_weighting["head_mode"] = rabc_head_mode
|
||||
migrated_config["sample_weighting"] = sample_weighting
|
||||
|
||||
return migrated_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
reward_model: RewardModelConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
@@ -72,27 +113,41 @@ class TrainPipelineConfig(HubMixin):
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def is_reward_model_training(self) -> bool:
|
||||
"""True when the config targets a reward model rather than a policy."""
|
||||
return self.reward_model is not None
|
||||
|
||||
@property
|
||||
def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
|
||||
"""Return whichever config (policy or reward_model) is active."""
|
||||
if self.is_reward_model_training:
|
||||
return self.reward_model # type: ignore[return-value]
|
||||
return self.policy # type: ignore[return-value]
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
reward_model_path = parser.get_path_arg("reward_model")
|
||||
|
||||
if reward_model_path:
|
||||
cli_overrides = parser.get_cli_overrides("reward_model")
|
||||
self.reward_model = RewardModelConfig.from_pretrained(
|
||||
reward_model_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||
elif policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
@@ -108,18 +163,22 @@ class TrainPipelineConfig(HubMixin):
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.pretrained_path = str(policy_dir)
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
if self.policy is None:
|
||||
if self.policy is None and self.reward_model is None:
|
||||
raise ValueError(
|
||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||
"Neither policy nor reward_model is configured. "
|
||||
"Please specify one with `--policy.path` or `--reward_model.path`."
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{self.policy.type}"
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
self.job_name = f"{self.env.type}_{active_cfg.type}"
|
||||
|
||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||
raise FileExistsError(
|
||||
@@ -137,26 +196,16 @@ class TrainPipelineConfig(HubMixin):
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
elif self.use_policy_training_preset and not self.resume:
|
||||
self.optimizer = self.policy.get_optimizer_preset()
|
||||
self.scheduler = self.policy.get_scheduler_preset()
|
||||
self.optimizer = active_cfg.get_optimizer_preset()
|
||||
self.scheduler = active_cfg.get_scheduler_preset()
|
||||
|
||||
if self.policy.push_to_hub and not self.policy.repo_id:
|
||||
raise ValueError(
|
||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||
)
|
||||
|
||||
if self.use_rabc and not self.rabc_progress_path:
|
||||
# Auto-detect from dataset path
|
||||
repo_id = self.dataset.repo_id
|
||||
if self.dataset.root:
|
||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||
else:
|
||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
||||
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
"""Keys for draccus pretrained-path loading."""
|
||||
return ["policy", "reward_model"]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||
@@ -207,5 +256,23 @@ class TrainPipelineConfig(HubMixin):
|
||||
) from e
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
||||
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
|
||||
if config_file is not None and config_file.endswith(".json"):
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||
if migrated_config is not None:
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(migrated_config, f)
|
||||
config_file = f.name
|
||||
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
@@ -97,8 +97,8 @@ def update_data_df(df, src_meta, dst_meta):
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||
"""
|
||||
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||
df["index"] = df["index"] + dst_meta.info.total_frames
|
||||
|
||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
||||
@@ -225,9 +225,9 @@ def update_meta_data(
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info.total_frames
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info.total_frames
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||
|
||||
return df
|
||||
|
||||
@@ -237,8 +237,8 @@ def aggregate_datasets(
|
||||
aggr_repo_id: str,
|
||||
roots: list[Path] | None = None,
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
@@ -313,8 +313,8 @@ def aggregate_datasets(
|
||||
# to avoid interference between different source datasets
|
||||
data_idx.pop("src_to_dst", None)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
dst_meta.info.total_episodes += src_meta.total_episodes
|
||||
dst_meta.info.total_frames += src_meta.total_frames
|
||||
|
||||
finalize_aggregation(dst_meta, all_metadata)
|
||||
logging.info("Aggregation complete.")
|
||||
@@ -640,14 +640,10 @@ def finalize_aggregation(aggr_meta, all_metadata):
|
||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||
|
||||
logging.info("write info")
|
||||
aggr_meta.info.update(
|
||||
{
|
||||
"total_tasks": len(aggr_meta.tasks),
|
||||
"total_episodes": sum(m.total_episodes for m in all_metadata),
|
||||
"total_frames": sum(m.total_frames for m in all_metadata),
|
||||
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
|
||||
}
|
||||
)
|
||||
aggr_meta.info.total_tasks = len(aggr_meta.tasks)
|
||||
aggr_meta.info.total_episodes = sum(m.total_episodes for m in all_metadata)
|
||||
aggr_meta.info.total_frames = sum(m.total_frames for m in all_metadata)
|
||||
aggr_meta.info.splits = {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}
|
||||
write_info(aggr_meta.info, aggr_meta.root)
|
||||
|
||||
logging.info("write stats")
|
||||
|
||||
@@ -37,13 +37,11 @@ from .io_utils import (
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_json,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
check_version_compatibility,
|
||||
get_safe_version,
|
||||
has_legacy_hub_download_metadata,
|
||||
@@ -228,7 +226,7 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
return packaging.version.parse(self.info.codebase_version)
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
"""Return the relative parquet file path for the given episode index.
|
||||
@@ -283,27 +281,27 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info["data_path"]
|
||||
return self.info.data_path
|
||||
|
||||
@property
|
||||
def video_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
return self.info.video_path
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
return self.info["robot_type"]
|
||||
return self.info.robot_type
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info["fps"]
|
||||
return self.info.fps
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
"""All features contained in the dataset."""
|
||||
return self.info["features"]
|
||||
return self.info.features
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
@@ -333,32 +331,32 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
return self.info.total_episodes
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
return self.info.total_frames
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
return self.info.total_tasks
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of files per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
return self.info.chunks_size
|
||||
|
||||
@property
|
||||
def data_files_size_in_mb(self) -> int:
|
||||
"""Max size of data file in mega bytes."""
|
||||
return self.info["data_files_size_in_mb"]
|
||||
return self.info.data_files_size_in_mb
|
||||
|
||||
@property
|
||||
def video_files_size_in_mb(self) -> int:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
return self.info.video_files_size_in_mb
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
@@ -502,10 +500,10 @@ class LeRobotDatasetMetadata:
|
||||
self._save_episode_metadata(episode_dict)
|
||||
|
||||
# Update info
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info.total_episodes += 1
|
||||
self.info.total_frames += episode_length
|
||||
self.info.total_tasks = len(self.tasks)
|
||||
self.info.splits = {"train": f"0:{self.info.total_episodes}"}
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
@@ -524,7 +522,7 @@ class LeRobotDatasetMetadata:
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
self.info.features[key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
@@ -546,17 +544,17 @@ class LeRobotDatasetMetadata:
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||
self.info["chunks_size"] = chunks_size
|
||||
self.info.chunks_size = chunks_size
|
||||
|
||||
if data_files_size_in_mb is not None:
|
||||
if data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
||||
self.info.data_files_size_in_mb = data_files_size_in_mb
|
||||
|
||||
if video_files_size_in_mb is not None:
|
||||
if video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
self.info.video_files_size_in_mb = video_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
@@ -653,7 +651,7 @@ class LeRobotDatasetMetadata:
|
||||
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
||||
"Either remove video features from the features dict, or set 'use_videos=True'."
|
||||
)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
write_info(obj.info, obj.root)
|
||||
obj.revision = None
|
||||
obj._pq_writer = None
|
||||
obj.latest_episode = None
|
||||
|
||||
@@ -897,14 +897,10 @@ def _copy_and_reindex_episodes_metadata(
|
||||
|
||||
dst_meta.finalize()
|
||||
|
||||
dst_meta.info.update(
|
||||
{
|
||||
"total_episodes": len(episode_mapping),
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0,
|
||||
"splits": {"train": f"0:{len(episode_mapping)}"},
|
||||
}
|
||||
)
|
||||
dst_meta.info.total_episodes = len(episode_mapping)
|
||||
dst_meta.info.total_frames = total_frames
|
||||
dst_meta.info.total_tasks = len(dst_meta.tasks) if dst_meta.tasks is not None else 0
|
||||
dst_meta.info.splits = {"train": f"0:{len(episode_mapping)}"}
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
if not all_stats:
|
||||
@@ -1069,21 +1065,20 @@ def _copy_episodes_metadata_and_stats(
|
||||
if episodes_dir.exists():
|
||||
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
|
||||
|
||||
dst_meta.info.update(
|
||||
{
|
||||
"total_episodes": src_dataset.meta.total_episodes,
|
||||
"total_frames": src_dataset.meta.total_frames,
|
||||
"total_tasks": src_dataset.meta.total_tasks,
|
||||
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
|
||||
}
|
||||
dst_meta.info.total_episodes = src_dataset.meta.total_episodes
|
||||
dst_meta.info.total_frames = src_dataset.meta.total_frames
|
||||
dst_meta.info.total_tasks = src_dataset.meta.total_tasks
|
||||
# Preserve original splits if available, otherwise create default
|
||||
dst_meta.info.splits = (
|
||||
src_dataset.meta.info.splits
|
||||
if src_dataset.meta.info.splits
|
||||
else {"train": f"0:{src_dataset.meta.total_episodes}"}
|
||||
)
|
||||
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
|
||||
"info", {}
|
||||
)
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
@@ -1525,7 +1520,7 @@ def modify_tasks(
|
||||
write_tasks(new_task_df, root)
|
||||
|
||||
# Update info.json
|
||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
||||
dataset.meta.info.total_tasks = len(unique_tasks)
|
||||
write_info(dataset.meta.info, root)
|
||||
|
||||
# Reload metadata to reflect changes
|
||||
@@ -1858,10 +1853,10 @@ def convert_image_to_video_dataset(
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
new_meta.info.total_episodes = len(episode_indices)
|
||||
new_meta.info.total_frames = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info.total_tasks = dataset.meta.total_tasks
|
||||
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
@@ -1870,7 +1865,7 @@ def convert_image_to_video_dataset(
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from pprint import pformat
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
|
||||
@@ -30,12 +31,14 @@ from .streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
|
||||
) -> dict[str, list] | None:
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
||||
cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
|
||||
``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
|
||||
``{observation,action,reward}_delta_indices`` properties used below.
|
||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||
delta_timestamps against.
|
||||
|
||||
@@ -82,7 +85,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
|
||||
if not cfg.dataset.streaming:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
|
||||
@@ -28,6 +28,7 @@ from .utils import (
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
DatasetInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,8 +79,8 @@ def create_empty_dataset_info(
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
) -> DatasetInfo:
|
||||
"""Create a template ``DatasetInfo`` object for a new dataset's ``meta/info.json``.
|
||||
|
||||
Args:
|
||||
codebase_version (str): The version of the LeRobot codebase.
|
||||
@@ -87,25 +88,24 @@ def create_empty_dataset_info(
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
chunks_size (int | None): Max files per chunk directory. Defaults to ``DEFAULT_CHUNK_SIZE``.
|
||||
data_files_size_in_mb (int | None): Max parquet file size in MB. Defaults to ``DEFAULT_DATA_FILE_SIZE_IN_MB``.
|
||||
video_files_size_in_mb (int | None): Max video file size in MB. Defaults to ``DEFAULT_VIDEO_FILE_SIZE_IN_MB``.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
DatasetInfo: A typed dataset information object with initial metadata.
|
||||
"""
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
return DatasetInfo(
|
||||
codebase_version=codebase_version,
|
||||
fps=fps,
|
||||
features=features,
|
||||
robot_type=robot_type,
|
||||
chunks_size=chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
data_files_size_in_mb=data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb=video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_path=DEFAULT_DATA_PATH,
|
||||
video_path=DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
)
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
|
||||
@@ -39,6 +39,7 @@ from .utils import (
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
DatasetInfo,
|
||||
serialize_dict,
|
||||
)
|
||||
|
||||
@@ -115,25 +116,21 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
return dataset
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path) -> None:
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
def write_info(info: DatasetInfo, local_dir: Path) -> None:
|
||||
write_json(info.to_dict(), local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
def load_info(local_dir: Path) -> DatasetInfo:
|
||||
"""Load dataset info metadata from its standard file path.
|
||||
|
||||
Also converts shape lists to tuples for consistency.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: The dataset information dictionary.
|
||||
DatasetInfo: The typed dataset information object.
|
||||
"""
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
raw = load_json(local_dir / INFO_PATH)
|
||||
return DatasetInfo.from_dict(raw)
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path) -> None:
|
||||
|
||||
@@ -123,7 +123,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info["fps"]
|
||||
return self._datasets[0].meta.info.fps
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
@@ -133,7 +133,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info.get("video", False)
|
||||
return len(self._datasets[0].meta.video_keys) > 0
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
|
||||
@@ -434,7 +434,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
def _make_padding_camera_frame(self, camera_key: str):
|
||||
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||
return torch.zeros(self.meta.info.features[camera_key]["shape"]).permute(-1, 0, 1)
|
||||
|
||||
def _get_video_frame_padding_mask(
|
||||
self,
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -70,6 +72,9 @@ class ForwardCompatibilityError(CompatibilityError):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
@@ -94,6 +99,123 @@ LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetInfo:
|
||||
"""Typed representation of the ``meta/info.json`` file for a LeRobot dataset.
|
||||
|
||||
Replaces the previously untyped ``dict`` returned by ``load_info()`` and
|
||||
created by ``create_empty_dataset_info()``. Using a dataclass provides
|
||||
explicit field definitions, IDE auto-completion, and validation at
|
||||
construction time.
|
||||
"""
|
||||
|
||||
codebase_version: str
|
||||
fps: int
|
||||
features: dict[str, dict]
|
||||
|
||||
# Episode / frame counters — start at zero for new datasets
|
||||
total_episodes: int = 0
|
||||
total_frames: int = 0
|
||||
total_tasks: int = 0
|
||||
|
||||
# Storage settings
|
||||
chunks_size: int = field(default=DEFAULT_CHUNK_SIZE)
|
||||
data_files_size_in_mb: int = field(default=DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: int = field(default=DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
|
||||
# File path templates
|
||||
data_path: str = field(default=DEFAULT_DATA_PATH)
|
||||
video_path: str | None = field(default=DEFAULT_VIDEO_PATH)
|
||||
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
# returns lists, but the rest of the codebase expects tuples.
|
||||
for ft in self.features.values():
|
||||
if isinstance(ft.get("shape"), list):
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||
if self.chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {self.chunks_size}")
|
||||
if self.data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {self.data_files_size_in_mb}")
|
||||
if self.video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {self.video_files_size_in_mb}")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DatasetInfo":
|
||||
"""Construct from a raw dict (e.g. loaded directly from JSON).
|
||||
|
||||
Unknown keys are ignored for forward compatibility with datasets that
|
||||
carry additional fields (e.g. ``total_videos`` from v2.x). A warning is
|
||||
logged when such fields are present.
|
||||
"""
|
||||
known = {f.name for f in dataclasses.fields(cls)}
|
||||
unknown = sorted(k for k in data if k not in known)
|
||||
if unknown:
|
||||
logger.warning(f"Unknown fields in DatasetInfo: {unknown}. These will be ignored.")
|
||||
return cls(**{k: v for k, v in data.items() if k in known})
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Temporary dict-style compatibility layer
|
||||
# Allows existing ``info["key"]`` call-sites to keep working without changes.
|
||||
# Once all callers have been migrated to attribute access, remove these.
|
||||
# ---------------------------------------------------------------------------
|
||||
def __getitem__(self, key: str):
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Accessing DatasetInfo with dict-style syntax info['{key}'] is deprecated. "
|
||||
f"Use attribute access info.{key} instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError as err:
|
||||
raise KeyError(key) from err
|
||||
|
||||
def __setitem__(self, key: str, value) -> None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Setting DatasetInfo with dict-style syntax info['{key}'] = ... is deprecated. "
|
||||
f"Use attribute assignment info.{key} = ... instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not hasattr(self, key):
|
||||
raise KeyError(f"DatasetInfo has no field '{key}'")
|
||||
setattr(self, key, value)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""Check if a field exists (dict-like interface)."""
|
||||
return hasattr(self, key)
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
"""Get attribute value with default fallback (dict-like interface)."""
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
return default
|
||||
|
||||
|
||||
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
||||
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
||||
|
||||
@@ -294,7 +416,7 @@ def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) ->
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
tags: list | None = None,
|
||||
dataset_info: dict | None = None,
|
||||
dataset_info: DatasetInfo | None = None,
|
||||
**kwargs,
|
||||
) -> DatasetCard:
|
||||
"""Create a `DatasetCard` for a LeRobot dataset.
|
||||
@@ -305,7 +427,7 @@ def create_lerobot_dataset_card(
|
||||
|
||||
Args:
|
||||
tags (list | None): A list of tags to add to the dataset card.
|
||||
dataset_info (dict | None): The dataset's info dictionary, which will
|
||||
dataset_info (DatasetInfo | None): The dataset's info object, which will
|
||||
be displayed on the card.
|
||||
**kwargs: Additional keyword arguments to populate the card template.
|
||||
|
||||
@@ -318,7 +440,7 @@ def create_lerobot_dataset_card(
|
||||
card_tags += tags
|
||||
if dataset_info:
|
||||
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info.to_dict(), indent=4)}\n```\n"
|
||||
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
||||
card_data = DatasetCardData(
|
||||
license=kwargs.get("license"),
|
||||
|
||||
@@ -282,7 +282,11 @@ class VideoDecoderCache:
|
||||
with self._lock:
|
||||
if video_path not in self._cache:
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
|
||||
return self._cache[video_path][0]
|
||||
|
||||
@@ -24,7 +24,12 @@ import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, PolicyProcessorPipeline
|
||||
from lerobot.processor import (
|
||||
IsaaclabArenaProcessorStep,
|
||||
LiberoActionProcessorStep,
|
||||
LiberoProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
)
|
||||
from lerobot.robots import RobotConfig
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
from lerobot.utils.constants import (
|
||||
@@ -123,7 +128,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
vec = env_cls([_make_one for _ in range(n_envs)], **extra_kwargs)
|
||||
return {self.type: {0: vec}}
|
||||
|
||||
def get_env_processors(self):
|
||||
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
|
||||
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
||||
|
||||
@@ -299,7 +304,6 @@ class HILSerlProcessorConfig:
|
||||
inverse_kinematics: InverseKinematicsConfig | None = None
|
||||
reward_classifier: RewardClassifierConfig | None = None
|
||||
max_gripper_pos: float | None = 100.0
|
||||
gripper_speed_factor: float | None = None
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@@ -437,10 +441,13 @@ class LiberoEnv(EnvConfig):
|
||||
is_libero_plus=self.is_libero_plus,
|
||||
)
|
||||
|
||||
def get_env_processors(self):
|
||||
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||
max_state_dim = getattr(policy_cfg, "max_state_dim", None) if getattr(policy_cfg, "type", None) == "evo1" else None
|
||||
action_feature = self.features.get(ACTION)
|
||||
action_dim = int(action_feature.shape[0]) if action_feature is not None else 7
|
||||
return (
|
||||
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
PolicyProcessorPipeline(steps=[LiberoProcessorStep(max_state_dim=max_state_dim)]),
|
||||
PolicyProcessorPipeline(steps=[LiberoActionProcessorStep(action_dim=action_dim)]),
|
||||
)
|
||||
|
||||
|
||||
@@ -706,7 +713,7 @@ class IsaaclabArenaEnv(HubEnvConfig):
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
def get_env_processors(self):
|
||||
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
|
||||
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
|
||||
if not state_keys and not camera_keys:
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
@@ -52,7 +53,14 @@ def make_env_pre_post_processors(
|
||||
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
return env_cfg.get_env_processors()
|
||||
get_processors = env_cfg.get_env_processors
|
||||
signature = inspect.signature(get_processors)
|
||||
supports_policy_cfg = "policy_cfg" in signature.parameters or any(
|
||||
param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
|
||||
)
|
||||
if supports_policy_cfg:
|
||||
return get_processors(policy_cfg=policy_cfg)
|
||||
return get_processors()
|
||||
|
||||
|
||||
def make_env(
|
||||
|
||||
@@ -16,18 +16,16 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .evo1.configuration_evo1 import Evo1Config as Evo1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .gaussian_actor.reward_model.configuration_classifier import (
|
||||
RewardClassifierConfig as RewardClassifierConfig,
|
||||
)
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
@@ -35,22 +33,22 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
|
||||
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
|
||||
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
|
||||
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
|
||||
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
|
||||
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"GaussianActorConfig",
|
||||
"Evo1Config",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"EO1Config",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"RewardClassifierConfig",
|
||||
"SARMConfig",
|
||||
"SACConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
horizon: int = 64
|
||||
n_action_steps: int = 32
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
crop_ratio: float = 1.0
|
||||
crop_shape: tuple[int, int] | None = None
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
use_group_norm: bool = False
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
use_separate_rgb_encoder_per_camera: bool = False
|
||||
use_separate_rgb_encoder_per_camera: bool = True
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/eo1.mdx
|
||||
@@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from .configuration_eo1 import EO1Config
|
||||
from .modeling_eo1 import EO1Policy
|
||||
from .processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]
|
||||
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLTextConfig,
|
||||
Qwen2_5_VLVisionConfig,
|
||||
)
|
||||
else:
|
||||
Qwen2_5_VLConfig = None
|
||||
Qwen2_5_VLTextConfig = None
|
||||
Qwen2_5_VLVisionConfig = None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("eo1")
|
||||
@dataclass
|
||||
class EO1Config(PreTrainedConfig):
|
||||
"""Configuration for native EO1 policy integration in LeRobot."""
|
||||
|
||||
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
vlm_config: dict | None = None
|
||||
|
||||
# Vision processor settings.
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
# Execution and action horizon.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 8
|
||||
n_action_steps: int = 8
|
||||
|
||||
# State/action padding to match EO1 flow head dimensionality.
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching sampling.
|
||||
num_denoise_steps: int = 10
|
||||
num_action_layers: int = 2
|
||||
action_act: str = "linear"
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
supervise_padding_action_dims: bool = True
|
||||
supervise_padding_actions: bool = True
|
||||
|
||||
# Policy-level dtype request for the Qwen backbone.
|
||||
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
|
||||
# The EO1 flow-matching head still keeps its own parameters in fp32.
|
||||
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
|
||||
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
|
||||
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
|
||||
force_fp32_autocast: bool = True
|
||||
|
||||
# Optional attention backend request passed through to the Qwen backbone.
|
||||
# Common values: None, "eager", "sdpa", "flash_attention_2".
|
||||
attn_implementation: str | None = None
|
||||
|
||||
# Training settings.
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.1
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# Populate the serialized backbone config only when the caller did not provide one.
|
||||
if self.vlm_config is None:
|
||||
require_package("transformers", extra="eo1")
|
||||
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
|
||||
|
||||
@property
|
||||
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
|
||||
require_package("transformers", extra="eo1")
|
||||
config_dict = deepcopy(self.vlm_config)
|
||||
if self.attn_implementation is not None:
|
||||
config_dict["attn_implementation"] = self.attn_implementation
|
||||
return Qwen2_5_VLConfig(**config_dict)
|
||||
|
||||
@property
|
||||
def text_config(self) -> Qwen2_5_VLTextConfig:
|
||||
return self.vlm_backbone_config.text_config
|
||||
|
||||
@property
|
||||
def vision_config(self) -> Qwen2_5_VLVisionConfig:
|
||||
return self.vlm_backbone_config.vision_config
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up EO1 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"EO1 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -0,0 +1,620 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils import torch_compilable_check
|
||||
else:
|
||||
ACT2FN = None
|
||||
Qwen2_5_VLForConditionalGeneration = None
|
||||
torch_compilable_check = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] >= new_dim:
|
||||
return vector
|
||||
return F.pad(vector, (0, new_dim - vector.shape[-1]))
|
||||
|
||||
|
||||
class EO1Policy(PreTrainedPolicy):
|
||||
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
|
||||
|
||||
config_class = EO1Config
|
||||
name = "eo1"
|
||||
|
||||
def __init__(self, config: EO1Config, **kwargs):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
if config.pretrained_path is None:
|
||||
# Initialize from pretrained VLM
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
config.vlm_base,
|
||||
dtype=config.dtype,
|
||||
attn_implementation=config.attn_implementation,
|
||||
)
|
||||
else:
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
|
||||
config.vlm_backbone_config,
|
||||
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
|
||||
)
|
||||
|
||||
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
|
||||
if config.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
self.model.to(config.device)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||
|
||||
@staticmethod
|
||||
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
|
||||
return {key: value for key, value in batch.items() if key not in excluded_keys}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
state = self.prepare_state(batch[OBS_STATE])
|
||||
actions = self.prepare_action(batch[ACTION])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
|
||||
loss = self.model(states=state, action=actions, **model_inputs)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
states = self.prepare_state(batch[OBS_STATE])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
|
||||
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
return actions[:, :, :original_action_dim]
|
||||
|
||||
def prepare_state(self, state: Tensor) -> Tensor:
|
||||
return pad_vector(state, self.config.max_state_dim)
|
||||
|
||||
def prepare_action(self, action: Tensor) -> Tensor:
|
||||
return pad_vector(action, self.config.max_action_dim)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
return torch.float32
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,)).to(device)
|
||||
|
||||
|
||||
class EO1VisionActionProjector(torch.nn.Sequential):
|
||||
"""This block implements the multi-layer perceptron (MLP) module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 2,
|
||||
activation_layer: str = "linear",
|
||||
bias: bool = True,
|
||||
device: Any = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
layers = []
|
||||
in_dim = in_channels
|
||||
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
|
||||
for hidden_dim in hidden_channels[:-1]:
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
|
||||
layers.append(ACT2FN[activation_layer])
|
||||
in_dim = hidden_dim
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
|
||||
super().__init__(*layers)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self[0].weight.dtype
|
||||
|
||||
|
||||
class EO1VisionFlowMatchingModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: EO1Config,
|
||||
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
|
||||
):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
|
||||
self.vlm_backbone = vlm_backbone
|
||||
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
|
||||
max_state_dim = config.max_state_dim
|
||||
max_action_dim = config.max_action_dim
|
||||
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_out_proj = EO1VisionActionProjector(
|
||||
self.hidden_size,
|
||||
max_action_dim,
|
||||
config.num_action_layers,
|
||||
config.action_act,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
|
||||
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.vlm_backbone.get_input_embeddings()
|
||||
|
||||
def flow_head_autocast_context(self):
|
||||
if self.config.force_fp32_autocast:
|
||||
return torch.autocast(
|
||||
device_type=self.state_proj.weight.device.type,
|
||||
enabled=False,
|
||||
)
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.vlm_backbone.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.vlm_backbone.gradient_checkpointing_disable()
|
||||
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
|
||||
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(
|
||||
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
||||
)
|
||||
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def get_placeholder_mask(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None,
|
||||
inputs_embeds: torch.FloatTensor | None,
|
||||
state_features: torch.FloatTensor | None = None,
|
||||
action_features: torch.FloatTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
|
||||
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
|
||||
if input_ids is None:
|
||||
special_state_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_state_mask = special_state_mask.all(-1)
|
||||
special_action_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_action_mask = special_action_mask.all(-1)
|
||||
else:
|
||||
special_state_mask = input_ids == state_token_id
|
||||
special_action_mask = input_ids == action_token_id
|
||||
|
||||
n_state_tokens = special_state_mask.sum()
|
||||
special_state_mask = (
|
||||
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if state_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_state_mask].numel() == state_features.numel(),
|
||||
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
|
||||
)
|
||||
|
||||
n_action_tokens = special_action_mask.sum()
|
||||
special_action_mask = (
|
||||
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if action_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_action_mask].numel() == action_features.numel(),
|
||||
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
|
||||
)
|
||||
|
||||
return special_state_mask, special_action_mask
|
||||
|
||||
def embed_prefix(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
states: torch.Tensor,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
|
||||
|
||||
# Get the input embeddings for the input IDs
|
||||
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||
return self.get_input_embeddings()(input_ids)
|
||||
|
||||
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
|
||||
|
||||
# Project the states to the hidden size
|
||||
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
states = states.to(dtype=self.state_proj.weight.dtype)
|
||||
return self.state_proj(states)
|
||||
|
||||
state_embs = self._apply_checkpoint(state_proj_func, states)
|
||||
state_mask, _ = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_features=state_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
|
||||
return inputs_embeds
|
||||
|
||||
def embed_suffix(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
noisy_actions: torch.Tensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the suffix"""
|
||||
|
||||
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
time_embs = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.hidden_size,
|
||||
min_period=self.config.min_period,
|
||||
max_period=self.config.max_period,
|
||||
device=action_embs.device,
|
||||
)
|
||||
time_embs = time_embs.to(dtype=action_embs.dtype)
|
||||
time_embs = time_embs[:, None, :].expand_as(action_embs)
|
||||
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
|
||||
|
||||
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
|
||||
action_time_embs = self.action_time_mlp_in(action_time_embs)
|
||||
action_time_embs = F.silu(action_time_embs)
|
||||
return self.action_time_mlp_out(action_time_embs)
|
||||
|
||||
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
|
||||
return action_time_embs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.FloatTensor | None = None,
|
||||
action: torch.FloatTensor | None = None,
|
||||
action_is_pad: torch.BoolTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Run the EO1 training forward pass and compute the flow-matching loss."""
|
||||
|
||||
# 1. Build the EO1 prefix with state placeholders resolved.
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
|
||||
# 2. Sample the diffusion target and replace the action placeholders.
|
||||
time = self.sample_time(action.shape[0], inputs_embeds.device)
|
||||
noise = self.sample_noise(action.shape, inputs_embeds.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * action
|
||||
u_t = noise - action
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
_, action_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
action_features=action_time_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
|
||||
|
||||
# 3. Optionally drop padded action tokens from backbone attention.
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
|
||||
action_token_mask = action_mask[..., 0]
|
||||
action_padding_mask = torch.zeros_like(action_token_mask)
|
||||
action_padding_mask = action_padding_mask.masked_scatter(
|
||||
action_token_mask,
|
||||
action_is_pad.reshape(-1),
|
||||
)
|
||||
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
|
||||
|
||||
# 4. Run the Qwen backbone on the fused EO1 sequence.
|
||||
def vlm_forward_func(
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
pixel_values: torch.Tensor | None,
|
||||
image_grid_thw: torch.LongTensor | None,
|
||||
mm_token_type_ids: torch.IntTensor | None,
|
||||
) -> torch.FloatTensor:
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
use_cache=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
hidden_states = self._apply_checkpoint(
|
||||
vlm_forward_func,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
mm_token_type_ids,
|
||||
)
|
||||
action_hidden_states = hidden_states[action_mask[..., 0]]
|
||||
|
||||
# 5. Project the action-token hidden states back to the flow target space.
|
||||
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
return self.action_out_proj(action_hidden_states)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
|
||||
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
# 6. Apply the configured supervision mask and reduce the loss.
|
||||
if not self.config.supervise_padding_action_dims:
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[..., :original_action_dim]
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
losses = losses[~action_is_pad]
|
||||
|
||||
return losses.mean()
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.Tensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Sample actions from the model."""
|
||||
if states is None:
|
||||
raise ValueError("states are required for EO1 action sampling.")
|
||||
if mm_token_type_ids is None:
|
||||
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
|
||||
|
||||
# 1. Resolve the left-padded rollout prompt and locate the action span.
|
||||
chunk_size = self.config.chunk_size
|
||||
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
).clone()
|
||||
_, action_placeholder_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_mask = action_placeholder_mask[..., 0]
|
||||
token_counts = action_mask.sum(dim=1)
|
||||
if not torch.all(token_counts == chunk_size):
|
||||
raise ValueError(
|
||||
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
|
||||
)
|
||||
if action_mask.ne(action_mask[:1]).any():
|
||||
raise ValueError(
|
||||
"Batch inference expects all samples to share the same action token mask after left padding."
|
||||
)
|
||||
act_start = int(action_mask[0].to(torch.int64).argmax().item())
|
||||
act_end = act_start + self.config.chunk_size
|
||||
if not torch.all(action_mask[:, act_start:act_end]):
|
||||
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
|
||||
act_slice = slice(act_start, act_end)
|
||||
|
||||
# 2. Encode the fixed prefix once and cache its KV state.
|
||||
batch_size = input_ids.shape[0]
|
||||
device = inputs_embeds.device
|
||||
attention_mask = attention_mask.to(device)
|
||||
mm_token_type_ids = mm_token_type_ids.to(device)
|
||||
position_ids, _ = self.vlm_backbone.model.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
)
|
||||
position_ids = position_ids.to(device)
|
||||
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids[:, :act_start],
|
||||
attention_mask=attention_mask[:, :act_start],
|
||||
position_ids=position_ids[..., :act_start],
|
||||
inputs_embeds=inputs_embeds[:, :act_start],
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids[:, :act_start],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
x_t = self.sample_noise(
|
||||
(batch_size, chunk_size, self.config.max_action_dim),
|
||||
device,
|
||||
).to(dtype=self.action_in_proj.weight.dtype)
|
||||
dt = -1.0 / self.config.num_denoise_steps
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
|
||||
for step in range(self.config.num_denoise_steps):
|
||||
time = torch.full(
|
||||
(batch_size,),
|
||||
1.0 + step * dt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
|
||||
|
||||
# Keep the prefix KV cache invariant across denoising steps.
|
||||
past_key_values.crop(act_start)
|
||||
outputs = self.vlm_backbone.model(
|
||||
attention_mask=attention_mask[:, :act_end],
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds[:, act_slice],
|
||||
position_ids=position_ids[..., act_slice],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
with self.flow_head_autocast_context():
|
||||
hidden_states = outputs.last_hidden_state[:, :chunk_size]
|
||||
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
v_t = self.action_out_proj(hidden_states)
|
||||
|
||||
x_t += dt * v_t.reshape(x_t.shape)
|
||||
|
||||
return x_t
|
||||
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
else:
|
||||
Qwen2_5_VLProcessor = None
|
||||
|
||||
SYSTEM_MESSAGE = "You are a helpful physical assistant."
|
||||
|
||||
# EO-1 special tokens
|
||||
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
|
||||
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
|
||||
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
|
||||
STATE_START_TOKEN = "<|state_start|>" # nosec B105
|
||||
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
|
||||
STATE_END_TOKEN = "<|state_end|>" # nosec B105
|
||||
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
|
||||
|
||||
EO1_SPECIAL_TOKENS = [
|
||||
ACTION_START_TOKEN,
|
||||
DEFAULT_ACTION_TOKEN,
|
||||
ACTION_END_TOKEN,
|
||||
STATE_START_TOKEN,
|
||||
DEFAULT_STATE_TOKEN,
|
||||
STATE_END_TOKEN,
|
||||
TASK_VLA_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
|
||||
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
|
||||
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
|
||||
chunk_size: int
|
||||
|
||||
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
# Robust JSON deserialization handling (guard empty maps).
|
||||
if self.input_features:
|
||||
first_val = next(iter(self.input_features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
reconstructed = {}
|
||||
for key, ft_dict in self.input_features.items():
|
||||
reconstructed[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.input_features = reconstructed
|
||||
|
||||
self._image_keys = [
|
||||
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
|
||||
]
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
tasks = complementary_data.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("Task is required for EO1ConversationTemplateStep.")
|
||||
|
||||
observation = self.transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
|
||||
|
||||
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
|
||||
raise ValueError("Batch size mismatch between observation.state and task list.")
|
||||
|
||||
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
|
||||
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
|
||||
images = {
|
||||
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
|
||||
}
|
||||
messages = []
|
||||
for i in range(len(tasks)):
|
||||
content = [
|
||||
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
|
||||
),
|
||||
},
|
||||
]
|
||||
messages.append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
|
||||
{"role": "user", "content": content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
complementary_data["messages"] = messages
|
||||
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only materializes EO1-specific message objects in complementary_data.
|
||||
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
|
||||
feature contract change to record here.
|
||||
"""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
|
||||
},
|
||||
"chunk_size": self.chunk_size,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
|
||||
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
|
||||
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
|
||||
_state_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
_action_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
require_package("transformers", extra="eo1")
|
||||
self._processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
self.processor_name,
|
||||
use_fast=self.use_fast_processor,
|
||||
)
|
||||
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
|
||||
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
|
||||
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
messages = complementary_data.pop("messages", None)
|
||||
if messages is None:
|
||||
raise ValueError("Messages are required for EO1QwenProcessorStep.")
|
||||
|
||||
# Rollout batches use left padding so action spans stay aligned across samples.
|
||||
# Supervised batches use right padding to match standard training collation.
|
||||
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
|
||||
|
||||
inputs = self._processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
padding=True,
|
||||
padding_side=padding_side,
|
||||
min_pixels=self.image_min_pixels,
|
||||
max_pixels=self.image_max_pixels,
|
||||
add_generation_prompt=False,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
complementary_data["input_ids"] = inputs["input_ids"]
|
||||
complementary_data["pixel_values"] = inputs["pixel_values"]
|
||||
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
|
||||
complementary_data["attention_mask"] = inputs["attention_mask"]
|
||||
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
|
||||
complementary_data["state_token_id"] = self._state_token_id
|
||||
complementary_data["action_token_id"] = self._action_token_id
|
||||
|
||||
return complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"processor_name": self.processor_name,
|
||||
"image_min_pixels": self.image_min_pixels,
|
||||
"image_max_pixels": self.image_max_pixels,
|
||||
"use_fast_processor": self.use_fast_processor,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only converts the messages to the model input format.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_eo1_pre_post_processors(
|
||||
config: EO1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build pre/post processor pipelines for EO1."""
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
|
||||
EO1QwenProcessorStep(
|
||||
processor_name=config.vlm_base,
|
||||
image_min_pixels=config.image_min_pixels,
|
||||
image_max_pixels=config.image_max_pixels,
|
||||
use_fast_processor=config.use_fast_processor,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_evo1_README.md
|
||||
@@ -12,9 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
||||
from .configuration_evo1 import Evo1Config
|
||||
from .modeling_evo1 import EVO1Policy
|
||||
from .processor_evo1 import make_evo1_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"SACAlgorithm",
|
||||
"SACAlgorithmConfig",
|
||||
]
|
||||
__all__ = ["Evo1Config", "EVO1Policy", "make_evo1_pre_post_processors"]
|
||||
@@ -0,0 +1,225 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("evo1_exact")
|
||||
@dataclass
|
||||
class Evo1SchedulerConfig(LRSchedulerConfig):
|
||||
num_warmup_steps: int
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step: int) -> float:
|
||||
if current_step < self.num_warmup_steps:
|
||||
return current_step / max(1, self.num_warmup_steps)
|
||||
progress = (current_step - self.num_warmup_steps) / max(
|
||||
1, num_training_steps - self.num_warmup_steps
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("evo1")
|
||||
@dataclass
|
||||
class Evo1Config(PreTrainedConfig):
|
||||
training_stage: str = "stage1"
|
||||
use_amp: bool = True
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
max_state_dim: int = 24
|
||||
max_action_dim: int = 24
|
||||
max_views: int = 3
|
||||
image_resolution: tuple[int, int] = (448, 448)
|
||||
empty_cameras: int = 0
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
|
||||
vlm_num_layers: int | None = 14
|
||||
vlm_dtype: str = "bfloat16"
|
||||
use_flash_attn: bool = True
|
||||
action_head: str = "flowmatching"
|
||||
embed_dim: int = 896
|
||||
hidden_dim: int = 1024
|
||||
state_hidden_dim: int = 1024
|
||||
num_heads: int = 8
|
||||
num_layers: int = 8
|
||||
dropout: float = 0.0
|
||||
num_inference_timesteps: int = 32
|
||||
num_categories: int = 1
|
||||
return_cls_only: bool = False
|
||||
enable_gradient_checkpointing: bool = True
|
||||
gradient_checkpointing_use_reentrant: bool = False
|
||||
|
||||
finetune_vlm: bool | None = None
|
||||
finetune_language_model: bool | None = None
|
||||
finetune_vision_model: bool | None = None
|
||||
finetune_action_head: bool | None = None
|
||||
# Reapply stage defaults after loading checkpoint configs so stage2 cannot
|
||||
# accidentally inherit the frozen VLM flags stored by a stage1 checkpoint.
|
||||
apply_training_stage_defaults: bool = True
|
||||
|
||||
task_field: str = "task"
|
||||
embodiment_id_field: str | None = None
|
||||
default_embodiment_id: int = 0
|
||||
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
scheduler_warmup_steps: int = 300
|
||||
drop_last: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.training_stage not in {"stage1", "stage2"}:
|
||||
raise ValueError(
|
||||
f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
|
||||
)
|
||||
|
||||
if self.apply_training_stage_defaults:
|
||||
if self.training_stage == "stage1":
|
||||
self.finetune_vlm = False
|
||||
self.finetune_language_model = False
|
||||
self.finetune_vision_model = False
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage2":
|
||||
self.finetune_vlm = True
|
||||
self.finetune_language_model = True
|
||||
self.finetune_vision_model = True
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage1":
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = False
|
||||
if self.finetune_language_model is None:
|
||||
self.finetune_language_model = False
|
||||
if self.finetune_vision_model is None:
|
||||
self.finetune_vision_model = False
|
||||
if self.finetune_action_head is None:
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage2":
|
||||
has_explicit_branch_flags = any(
|
||||
flag is not None for flag in (self.finetune_language_model, self.finetune_vision_model)
|
||||
)
|
||||
if not has_explicit_branch_flags:
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = True
|
||||
if self.finetune_language_model is None:
|
||||
self.finetune_language_model = True
|
||||
if self.finetune_vision_model is None:
|
||||
self.finetune_vision_model = True
|
||||
elif self.finetune_vlm is None:
|
||||
self.finetune_vlm = bool(self.finetune_language_model or self.finetune_vision_model)
|
||||
if self.finetune_action_head is None:
|
||||
self.finetune_action_head = True
|
||||
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = False
|
||||
if self.finetune_language_model is None:
|
||||
self.finetune_language_model = False
|
||||
if self.finetune_vision_model is None:
|
||||
self.finetune_vision_model = False
|
||||
if self.finetune_action_head is None:
|
||||
self.finetune_action_head = False
|
||||
|
||||
branch_vlm = self.finetune_language_model or self.finetune_vision_model
|
||||
if self.finetune_vlm != branch_vlm:
|
||||
raise ValueError(
|
||||
"Inconsistent EVO1 finetune config: "
|
||||
f"finetune_vlm={self.finetune_vlm} but "
|
||||
f"(finetune_language_model or finetune_vision_model)={branch_vlm}. "
|
||||
"When branch-level flags are used, finetune_vlm must match their effective union."
|
||||
)
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.input_features is None:
|
||||
self.input_features = {}
|
||||
if self.output_features is None:
|
||||
self.output_features = {}
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||
if key not in self.input_features:
|
||||
self.input_features[key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution),
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
self.input_features[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
self.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return Evo1SchedulerConfig(
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
return [0]
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -0,0 +1,234 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
|
||||
from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder
|
||||
|
||||
|
||||
def _cfgget(config: Any, key: str, default=None):
|
||||
if isinstance(config, dict):
|
||||
return config.get(key, default)
|
||||
return getattr(config, key, default)
|
||||
|
||||
|
||||
class EVO1(nn.Module):
|
||||
def __init__(self, config: dict):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device = _cfgget(config, "device", "cuda")
|
||||
self.return_cls_only = _cfgget(config, "return_cls_only", False)
|
||||
vlm_name = _cfgget(config, "vlm_name", "OpenGVLab/InternVL3-1B")
|
||||
image_size = _cfgget(config, "image_size", 448)
|
||||
if image_size is None:
|
||||
image_resolution = _cfgget(config, "image_resolution", (448, 448))
|
||||
image_size = int(image_resolution[0])
|
||||
|
||||
self.embedder = InternVL3Embedder(
|
||||
model_name=vlm_name,
|
||||
image_size=image_size,
|
||||
device=self._device,
|
||||
num_language_layers=_cfgget(config, "vlm_num_layers", 14),
|
||||
model_dtype=_cfgget(config, "vlm_dtype", "bfloat16"),
|
||||
use_flash_attn=_cfgget(config, "use_flash_attn", True),
|
||||
enable_gradient_checkpointing=_cfgget(config, "enable_gradient_checkpointing", True),
|
||||
gradient_checkpointing_use_reentrant=_cfgget(
|
||||
config, "gradient_checkpointing_use_reentrant", False
|
||||
),
|
||||
)
|
||||
|
||||
action_head_type = _cfgget(config, "action_head", "flowmatching").lower()
|
||||
if action_head_type != "flowmatching":
|
||||
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
|
||||
|
||||
horizon = _cfgget(config, "action_horizon", _cfgget(config, "horizon", 16))
|
||||
per_action_dim = _cfgget(config, "per_action_dim", 7)
|
||||
action_dim = horizon * per_action_dim
|
||||
|
||||
if isinstance(config, dict):
|
||||
config["horizon"] = horizon
|
||||
config["per_action_dim"] = per_action_dim
|
||||
config["action_dim"] = action_dim
|
||||
|
||||
self.horizon = horizon
|
||||
self.per_action_dim = per_action_dim
|
||||
self.action_head = FlowmatchingActionHead(config=config).to(self._device)
|
||||
|
||||
def _normalize_image_batches(
|
||||
self,
|
||||
images: Sequence[Image.Image | torch.Tensor] | Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
prompt: str | list[str] | None,
|
||||
image_mask: torch.Tensor,
|
||||
) -> tuple[list[list[Image.Image | torch.Tensor]], list[str], torch.Tensor]:
|
||||
if not images:
|
||||
raise ValueError("EVO1 expects at least one image per sample.")
|
||||
|
||||
first = images[0]
|
||||
if isinstance(first, (Image.Image, torch.Tensor)):
|
||||
image_batches = [list(images)] # type: ignore[arg-type]
|
||||
else:
|
||||
image_batches = [list(sample) for sample in images] # type: ignore[arg-type]
|
||||
|
||||
batch_size = len(image_batches)
|
||||
if prompt is None:
|
||||
prompts = [""] * batch_size
|
||||
elif isinstance(prompt, str):
|
||||
prompts = [prompt] * batch_size
|
||||
else:
|
||||
prompts = [str(p) for p in prompt]
|
||||
if len(prompts) != batch_size:
|
||||
raise ValueError(
|
||||
f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}"
|
||||
)
|
||||
|
||||
if image_mask.dim() == 1:
|
||||
image_mask = image_mask.unsqueeze(0)
|
||||
if image_mask.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}"
|
||||
)
|
||||
|
||||
return image_batches, prompts, image_mask
|
||||
|
||||
def get_vl_embeddings(
|
||||
self,
|
||||
images: list[Image.Image | torch.Tensor] | list[list[Image.Image | torch.Tensor]],
|
||||
image_mask: torch.Tensor,
|
||||
prompt: str | list[str] | None = None,
|
||||
return_cls_only: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
if return_cls_only is None:
|
||||
return_cls_only = self.return_cls_only
|
||||
|
||||
image_batches, prompts, image_mask = self._normalize_image_batches(images, prompt, image_mask)
|
||||
return self.embedder.get_fused_image_text_embedding_from_tensor_images(
|
||||
image_tensors_batch=image_batches,
|
||||
image_masks=image_mask,
|
||||
text_prompts=prompts,
|
||||
return_cls_only=return_cls_only,
|
||||
)
|
||||
|
||||
def prepare_state(self, state_input: list | torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(state_input, list):
|
||||
state_tensor = torch.tensor(state_input)
|
||||
elif isinstance(state_input, torch.Tensor):
|
||||
state_tensor = state_input
|
||||
else:
|
||||
raise TypeError(f"Unsupported state input type: {type(state_input)}")
|
||||
|
||||
if state_tensor.ndim == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
|
||||
return state_tensor.to(self._device)
|
||||
|
||||
def predict_action(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor,
|
||||
actions_gt: torch.Tensor | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
):
|
||||
if actions_gt is None:
|
||||
return self.action_head.get_action(
|
||||
fused_tokens,
|
||||
state=state,
|
||||
action_mask=action_mask,
|
||||
embodiment_id=embodiment_ids,
|
||||
)
|
||||
return self.action_head(
|
||||
fused_tokens,
|
||||
state=state,
|
||||
actions_gt=actions_gt,
|
||||
action_mask=action_mask,
|
||||
embodiment_id=embodiment_ids,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_inference(
|
||||
self,
|
||||
images: list[Image.Image | torch.Tensor],
|
||||
image_mask: torch.Tensor,
|
||||
prompt: str,
|
||||
state_input: list | torch.Tensor,
|
||||
return_cls_only: bool | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if image_mask.dim() == 1:
|
||||
image_mask = image_mask.unsqueeze(0)
|
||||
|
||||
fused_tokens = self.get_vl_embeddings(
|
||||
images=[images],
|
||||
image_mask=image_mask,
|
||||
prompt=[prompt],
|
||||
return_cls_only=return_cls_only,
|
||||
)
|
||||
state_tensor = self.prepare_state(state_input)
|
||||
action = self.predict_action(
|
||||
fused_tokens,
|
||||
state_tensor,
|
||||
action_mask=action_mask,
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16:
|
||||
action = action.to(torch.float32)
|
||||
return action
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
actions_gt: torch.Tensor | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
):
|
||||
return self.predict_action(fused_tokens, state, actions_gt, action_mask, embodiment_ids)
|
||||
|
||||
def _set_module_trainable(self, module: nn.Module, trainable: bool):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = trainable
|
||||
|
||||
def set_finetune_flags(self):
|
||||
finetune_vlm = _cfgget(self.config, "finetune_vlm", False)
|
||||
finetune_language_model = _cfgget(self.config, "finetune_language_model", False)
|
||||
finetune_vision_model = _cfgget(self.config, "finetune_vision_model", False)
|
||||
has_explicit_branch_flags = any(
|
||||
flag is not None for flag in (finetune_language_model, finetune_vision_model)
|
||||
)
|
||||
finetune_language_model = bool(finetune_language_model)
|
||||
finetune_vision_model = bool(finetune_vision_model)
|
||||
finetune_vlm = bool(finetune_vlm)
|
||||
|
||||
if has_explicit_branch_flags:
|
||||
self._set_module_trainable(self.embedder, False)
|
||||
if hasattr(self.embedder.model, "language_model"):
|
||||
self._set_module_trainable(self.embedder.model.language_model, finetune_language_model)
|
||||
if hasattr(self.embedder.model, "vision_model"):
|
||||
self._set_module_trainable(self.embedder.model.vision_model, finetune_vision_model)
|
||||
if hasattr(self.embedder.model, "mlp1"):
|
||||
self._set_module_trainable(self.embedder.model.mlp1, finetune_vision_model)
|
||||
elif not finetune_vlm:
|
||||
self._set_module_trainable(self.embedder, False)
|
||||
|
||||
if not _cfgget(self.config, "finetune_action_head", False):
|
||||
self._set_module_trainable(self.action_head, False)
|
||||
@@ -0,0 +1,456 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfgget(config, key: str, default=None):
|
||||
if isinstance(config, dict):
|
||||
return config.get(key, default)
|
||||
return getattr(config, key, default)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int = 1000):
|
||||
super().__init__()
|
||||
pe = torch.zeros(max_len, dim)
|
||||
position = torch.arange(0, max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, seq_len: int):
|
||||
if seq_len > self.pe.size(1):
|
||||
self._extend_pe(seq_len)
|
||||
return self.pe[:, :seq_len, :]
|
||||
|
||||
def _extend_pe(self, new_max_len):
|
||||
old_max_len, dim = self.pe.size(1), self.pe.size(2)
|
||||
if new_max_len <= old_max_len:
|
||||
return
|
||||
extra_positions = torch.arange(old_max_len, new_max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
|
||||
extra_pe = torch.zeros(new_max_len - old_max_len, dim)
|
||||
extra_pe[:, 0::2] = torch.sin(extra_positions * div_term)
|
||||
extra_pe[:, 1::2] = torch.cos(extra_positions * div_term)
|
||||
extra_pe = extra_pe.unsqueeze(0)
|
||||
new_pe = torch.cat([self.pe, extra_pe.to(self.pe.device)], dim=1)
|
||||
self.pe = new_pe
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
if num_categories <= 1:
|
||||
self.linear = nn.Linear(in_dim, out_dim)
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim))
|
||||
self.bias = nn.Parameter(torch.zeros(num_categories, out_dim))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
|
||||
if self.num_categories <= 1:
|
||||
if x.dtype != self.linear.weight.dtype:
|
||||
x = x.to(dtype=self.linear.weight.dtype)
|
||||
return self.linear(x)
|
||||
|
||||
if x.dtype != self.weight.dtype:
|
||||
x = x.to(dtype=self.weight.dtype)
|
||||
|
||||
orig_shape = x.shape
|
||||
x_flat = x.reshape(-1, orig_shape[-1])
|
||||
if category_id.dim() == 0:
|
||||
cid = category_id.item()
|
||||
out = x_flat @ self.weight[cid] + self.bias[cid]
|
||||
else:
|
||||
category_id = category_id.reshape(-1)
|
||||
if category_id.numel() != x_flat.size(0):
|
||||
raise ValueError(
|
||||
f"category_id length {category_id.numel()} does not match flattened batch {x_flat.size(0)}"
|
||||
)
|
||||
weight_selected = self.weight[category_id]
|
||||
bias_selected = self.bias[category_id]
|
||||
out = torch.bmm(x_flat.unsqueeze(1), weight_selected).squeeze(1) + bias_selected
|
||||
out_shape = orig_shape[:-1] + (out.shape[-1],)
|
||||
return out.view(out_shape)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1):
|
||||
super().__init__()
|
||||
self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories)
|
||||
self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories)
|
||||
self.activation = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
|
||||
out = self.activation(self.fc1(x, category_id))
|
||||
out = self.fc2(out, category_id)
|
||||
return out
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1
|
||||
):
|
||||
super().__init__()
|
||||
self.horizon = horizon
|
||||
self.embed_dim = embed_dim
|
||||
self.num_categories = num_categories
|
||||
|
||||
self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories)
|
||||
self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories)
|
||||
self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories)
|
||||
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon)
|
||||
self.activation = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
|
||||
batch_size, horizon, action_dim = action_seq.shape
|
||||
assert self.horizon == horizon, "Action sequence length must match horizon"
|
||||
|
||||
x = action_seq.reshape(batch_size * horizon, action_dim)
|
||||
if category_id.dim() == 0:
|
||||
cat_ids = category_id.expand(horizon * batch_size)
|
||||
else:
|
||||
cat_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(batch_size * horizon)
|
||||
|
||||
out = self.activation(self.W1(x, cat_ids))
|
||||
pos_enc = self.pos_encoding(horizon).to(device=out.device, dtype=out.dtype)
|
||||
out = out.view(batch_size, horizon, -1) + pos_enc
|
||||
out = out.view(batch_size * horizon, -1)
|
||||
out = self.activation(self.W2(out, cat_ids))
|
||||
out = self.W3(out, cat_ids)
|
||||
return out.view(batch_size, horizon, self.embed_dim)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
|
||||
self.norm1 = nn.LayerNorm(embed_dim)
|
||||
self.norm2 = nn.LayerNorm(embed_dim)
|
||||
self.ff = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
|
||||
|
||||
def forward(self, action_tokens: torch.Tensor, context_tokens: torch.Tensor, time_emb: torch.Tensor):
|
||||
x = self.norm1(action_tokens)
|
||||
attn_out, _ = self.attn(x, context_tokens, context_tokens)
|
||||
x = action_tokens + attn_out
|
||||
x2 = self.norm2(x)
|
||||
if time_emb is not None:
|
||||
x2 = x2 + time_emb.unsqueeze(1)
|
||||
ff_out = self.ff(x2)
|
||||
return x + ff_out
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config=None,
|
||||
embed_dim: int = 896,
|
||||
hidden_dim: int = 1024,
|
||||
action_dim: int = 16 * 7,
|
||||
horizon: int = 16,
|
||||
per_action_dim: int = 7,
|
||||
num_heads: int = 8,
|
||||
num_layers: int = 8,
|
||||
dropout: float = 0.0,
|
||||
num_inference_timesteps: int = 20,
|
||||
num_categories: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if config is not None:
|
||||
embed_dim = _cfgget(config, "embed_dim", embed_dim)
|
||||
hidden_dim = _cfgget(config, "hidden_dim", hidden_dim)
|
||||
action_dim = _cfgget(config, "action_dim", action_dim)
|
||||
horizon = _cfgget(config, "horizon", horizon)
|
||||
per_action_dim = _cfgget(config, "per_action_dim", per_action_dim)
|
||||
num_heads = _cfgget(config, "num_heads", num_heads)
|
||||
num_layers = _cfgget(config, "num_layers", num_layers)
|
||||
dropout = _cfgget(config, "dropout", dropout)
|
||||
num_inference_timesteps = _cfgget(config, "num_inference_timesteps", num_inference_timesteps)
|
||||
num_categories = _cfgget(config, "num_categories", num_categories)
|
||||
self.config = config
|
||||
else:
|
||||
self.config = SimpleNamespace(
|
||||
embed_dim=embed_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
per_action_dim=per_action_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
num_inference_timesteps=num_inference_timesteps,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
|
||||
logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
|
||||
self.embed_dim = embed_dim
|
||||
self.horizon = horizon
|
||||
self.per_action_dim = _cfgget(self.config, "per_action_dim", per_action_dim)
|
||||
self.action_dim = _cfgget(self.config, "action_dim", action_dim)
|
||||
|
||||
self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
hidden_dim=embed_dim * 4,
|
||||
dropout=dropout,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(embed_dim)
|
||||
self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim)
|
||||
self.mlp_head = CategorySpecificMLP(
|
||||
input_dim=embed_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
output_dim=action_dim,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
|
||||
self.state_encoder = None
|
||||
state_dim = _cfgget(self.config, "state_dim")
|
||||
if state_dim is not None:
|
||||
state_hidden = _cfgget(self.config, "state_hidden_dim", embed_dim)
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
input_dim=state_dim,
|
||||
hidden_dim=state_hidden,
|
||||
output_dim=embed_dim,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
|
||||
if horizon > 1:
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=self.per_action_dim,
|
||||
embed_dim=embed_dim,
|
||||
hidden_dim=embed_dim,
|
||||
horizon=horizon,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
self.single_action_proj = None
|
||||
else:
|
||||
self.action_encoder = None
|
||||
self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim)
|
||||
|
||||
def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor:
|
||||
if self.horizon > 1 and self.action_encoder is not None:
|
||||
return self.action_encoder(action_seq, embodiment_id)
|
||||
if self.single_action_proj is None:
|
||||
raise RuntimeError("single_action_proj is not initialized for horizon <= 1.")
|
||||
return self.single_action_proj(action_seq)
|
||||
|
||||
def _expand_action_mask(
|
||||
self,
|
||||
action_mask: torch.Tensor,
|
||||
batch_size: int,
|
||||
per_action_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
if action_mask is None:
|
||||
raise ValueError("action_mask must be provided for flow matching inference.")
|
||||
|
||||
if action_mask.dim() == 2:
|
||||
expected_last_dim = self.horizon * per_action_dim
|
||||
if action_mask.shape == (batch_size, expected_last_dim):
|
||||
expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim)
|
||||
elif action_mask.shape == (batch_size, per_action_dim):
|
||||
expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected action_mask shape {(batch_size, expected_last_dim)} or "
|
||||
f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}"
|
||||
)
|
||||
elif action_mask.dim() == 3:
|
||||
expected_shape = (batch_size, self.horizon, per_action_dim)
|
||||
if tuple(action_mask.shape) != expected_shape:
|
||||
raise ValueError(
|
||||
f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}"
|
||||
)
|
||||
expanded_mask = action_mask
|
||||
else:
|
||||
raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}")
|
||||
|
||||
return expanded_mask.to(device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
actions_gt: torch.Tensor = None,
|
||||
embodiment_id: torch.LongTensor = None,
|
||||
state_mask: torch.Tensor = None,
|
||||
action_mask: torch.Tensor = None,
|
||||
):
|
||||
if actions_gt is None:
|
||||
return self.get_action(
|
||||
fused_tokens, state=state, embodiment_id=embodiment_id, action_mask=action_mask
|
||||
)
|
||||
|
||||
batch_size = fused_tokens.size(0)
|
||||
device = fused_tokens.device
|
||||
if embodiment_id is None:
|
||||
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
context_tokens = fused_tokens
|
||||
if state is not None and self.state_encoder is not None:
|
||||
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
|
||||
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
|
||||
|
||||
t = (
|
||||
torch.distributions.Beta(2, 2)
|
||||
.sample((batch_size,))
|
||||
.clamp(0.02, 0.98)
|
||||
.to(device)
|
||||
.to(dtype=self.dtype)
|
||||
)
|
||||
time_index = (t * 999).long().clamp_(0, 999)
|
||||
time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype)
|
||||
|
||||
actions_gt_seq = actions_gt
|
||||
noise = torch.rand_like(actions_gt) * 2 - 1
|
||||
if action_mask is not None:
|
||||
action_mask = action_mask.to(dtype=noise.dtype, device=noise.device)
|
||||
if action_mask.shape != noise.shape:
|
||||
raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}")
|
||||
actions_gt_seq = actions_gt_seq * action_mask
|
||||
noise = noise * action_mask
|
||||
|
||||
if self.horizon > 1:
|
||||
noise_seq = noise.view(batch_size, self.horizon, self.per_action_dim)
|
||||
else:
|
||||
noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1)
|
||||
t_broadcast = t.view(batch_size, 1, 1)
|
||||
action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq
|
||||
|
||||
action_tokens = self._project_actions(action_intermediate_seq, embodiment_id)
|
||||
target_dtype = self.dtype
|
||||
action_tokens = action_tokens.to(dtype=target_dtype)
|
||||
context_tokens = context_tokens.to(dtype=target_dtype)
|
||||
time_emb = time_emb.to(dtype=target_dtype)
|
||||
|
||||
x = action_tokens
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context_tokens, time_emb)
|
||||
x = self.norm_out(x)
|
||||
|
||||
if self.horizon > 1:
|
||||
x_flat = x.reshape(batch_size, -1)
|
||||
x_pooled = self.seq_pool_proj(x_flat)
|
||||
else:
|
||||
x_pooled = x.squeeze(1)
|
||||
|
||||
pred_velocity = self.mlp_head(x_pooled, embodiment_id)
|
||||
return pred_velocity, noise
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
embodiment_id: torch.LongTensor = None,
|
||||
action_mask: torch.Tensor = None,
|
||||
):
|
||||
batch_size = fused_tokens.size(0)
|
||||
device = fused_tokens.device
|
||||
if embodiment_id is None:
|
||||
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
context_tokens = fused_tokens
|
||||
if state is not None and self.state_encoder is not None:
|
||||
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
|
||||
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
|
||||
|
||||
action_dim_total = _cfgget(self.config, "action_dim", self.action_dim)
|
||||
per_action_dim = _cfgget(self.config, "per_action_dim", action_dim_total // max(self.horizon, 1))
|
||||
|
||||
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
|
||||
action_seq = (
|
||||
action.view(batch_size, self.horizon, per_action_dim)
|
||||
if self.horizon > 1
|
||||
else action.view(batch_size, 1, per_action_dim)
|
||||
)
|
||||
action_mask = self._expand_action_mask(
|
||||
action_mask,
|
||||
batch_size=batch_size,
|
||||
per_action_dim=per_action_dim,
|
||||
device=action_seq.device,
|
||||
dtype=action_seq.dtype,
|
||||
)
|
||||
action_seq = action_seq * action_mask
|
||||
|
||||
target_dtype = self.dtype
|
||||
context_tokens = context_tokens.to(dtype=target_dtype)
|
||||
|
||||
num_steps = int(_cfgget(self.config, "num_inference_timesteps", 32))
|
||||
if num_steps <= 0:
|
||||
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
for i in range(num_steps):
|
||||
t = i / num_steps
|
||||
time_index = min(int(t * 999), 999)
|
||||
time_emb = (
|
||||
self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=context_tokens.dtype)
|
||||
)
|
||||
time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
action_seq = action_seq * action_mask
|
||||
action_tokens = self._project_actions(action_seq, embodiment_id).to(dtype=target_dtype)
|
||||
time_emb = time_emb.to(dtype=target_dtype)
|
||||
|
||||
x = action_tokens
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context_tokens, time_emb)
|
||||
x = self.norm_out(x)
|
||||
|
||||
if self.horizon > 1:
|
||||
x_flat = x.reshape(batch_size, -1)
|
||||
x_pooled = self.seq_pool_proj(x_flat)
|
||||
else:
|
||||
x_pooled = x.squeeze(1)
|
||||
|
||||
pred = self.mlp_head(x_pooled, embodiment_id)
|
||||
action = action + dt * pred
|
||||
action_seq = (
|
||||
action.view(batch_size, self.horizon, per_action_dim)
|
||||
if self.horizon > 1
|
||||
else action.view(batch_size, 1, per_action_dim)
|
||||
)
|
||||
|
||||
action_seq = action_seq * action_mask
|
||||
return action_seq.reshape(batch_size, -1)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
@@ -0,0 +1,435 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoTokenizer = None
|
||||
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" # nosec B105
|
||||
IMG_START_TOKEN = "<img>" # nosec B105
|
||||
IMG_END_TOKEN = "</img>" # nosec B105
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None:
|
||||
if getattr(encoder, "_evo1_checkpoint_patch_applied", False):
|
||||
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||
return
|
||||
|
||||
original_forward = encoder.forward
|
||||
|
||||
def forward_with_checkpoint_kwargs(self, *args, **kwargs):
|
||||
original_checkpoint = torch.utils.checkpoint.checkpoint
|
||||
|
||||
def checkpoint(function, *checkpoint_args, **checkpoint_kwargs):
|
||||
checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant)
|
||||
return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs)
|
||||
|
||||
torch.utils.checkpoint.checkpoint = checkpoint
|
||||
try:
|
||||
return original_forward(*args, **kwargs)
|
||||
finally:
|
||||
torch.utils.checkpoint.checkpoint = original_checkpoint
|
||||
|
||||
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||
encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder)
|
||||
encoder._evo1_checkpoint_patch_applied = True
|
||||
|
||||
|
||||
def flash_attn_is_available() -> bool:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _internvl_transformers5_load_compatibility():
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
original_linspace = torch.linspace
|
||||
original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized
|
||||
|
||||
def linspace(*args, **kwargs):
|
||||
if kwargs.get("device") is None:
|
||||
kwargs["device"] = torch.device("cpu")
|
||||
return original_linspace(*args, **kwargs)
|
||||
|
||||
def mark_tied_weights_as_initialized(self, loading_info):
|
||||
if not hasattr(self, "all_tied_weights_keys"):
|
||||
self.all_tied_weights_keys = {}
|
||||
return original_mark_tied(self, loading_info)
|
||||
|
||||
torch.linspace = linspace
|
||||
PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.linspace = original_linspace
|
||||
PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=10000)
|
||||
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
|
||||
aspect_ratio = orig_width / orig_height
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
best_ratio_diff = float("inf")
|
||||
best_ratio = (1, 1)
|
||||
area = orig_width * orig_height
|
||||
for ratio in target_ratios:
|
||||
target_ar = ratio[0] / ratio[1]
|
||||
diff = abs(aspect_ratio - target_ar)
|
||||
if diff < best_ratio_diff:
|
||||
best_ratio_diff = diff
|
||||
best_ratio = ratio
|
||||
elif diff == best_ratio_diff and area > 0.5 * image_size**2 * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
|
||||
def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnail=False):
|
||||
orig_width, orig_height = image.size
|
||||
ratio_w, ratio_h = get_target_aspect_ratio(orig_width, orig_height, image_size, min_num, max_num)
|
||||
target_width = image_size * ratio_w
|
||||
target_height = image_size * ratio_h
|
||||
blocks = ratio_w * ratio_h
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size,
|
||||
)
|
||||
processed_images.append(resized_img.crop(box))
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
processed_images.append(image.resize((image_size, image_size)))
|
||||
return processed_images
|
||||
|
||||
|
||||
class InternVL3Embedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_name="OpenGVLab/InternVL3-1B",
|
||||
image_size=448,
|
||||
device="cuda",
|
||||
num_language_layers: int | None = 14,
|
||||
model_dtype: str | torch.dtype = "bfloat16",
|
||||
use_flash_attn: bool = True,
|
||||
enable_gradient_checkpointing: bool = True,
|
||||
gradient_checkpointing_use_reentrant: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self._requested_device = device
|
||||
self.image_size = image_size
|
||||
self.num_language_layers = num_language_layers
|
||||
self.max_text_length = 1024
|
||||
self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
|
||||
self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
|
||||
|
||||
require_package("transformers", extra="evo1")
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
|
||||
if isinstance(model_dtype, str):
|
||||
try:
|
||||
model_dtype = getattr(torch, model_dtype)
|
||||
except AttributeError as exc:
|
||||
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
|
||||
|
||||
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
|
||||
if use_flash_attn and not resolved_use_flash_attn:
|
||||
logger.warning("flash_attn is not installed. Falling back to standard attention.")
|
||||
|
||||
# InternVL3 remote code predates Transformers 5 post-init conventions:
|
||||
# it computes stochastic-depth scalars via torch.linspace(...).item()
|
||||
# while Transformers initializes under torch.device("meta"), and it
|
||||
# does not populate all_tied_weights_keys before loading finalization.
|
||||
with _internvl_transformers5_load_compatibility():
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=model_dtype,
|
||||
trust_remote_code=True,
|
||||
use_flash_attn=resolved_use_flash_attn,
|
||||
low_cpu_mem_usage=True,
|
||||
_fast_init=False,
|
||||
).to(self._requested_device)
|
||||
|
||||
if hasattr(self.model.language_model, "model"):
|
||||
layers = self.model.language_model.model.layers
|
||||
else:
|
||||
layers = self.model.language_model.layers
|
||||
if self.num_language_layers is not None:
|
||||
layers = layers[: self.num_language_layers]
|
||||
|
||||
if hasattr(self.model.language_model, "model"):
|
||||
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
|
||||
else:
|
||||
self.model.language_model.layers = torch.nn.ModuleList(layers)
|
||||
self.model.language_model.lm_head = torch.nn.Identity()
|
||||
|
||||
self._configure_memory_features()
|
||||
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
||||
|
||||
def _configure_memory_features(self) -> None:
|
||||
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
|
||||
|
||||
if not self.enable_gradient_checkpointing:
|
||||
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||
self.model.vision_model.encoder.gradient_checkpointing = False
|
||||
language_model = getattr(self.model, "language_model", None)
|
||||
if language_model is not None:
|
||||
if hasattr(language_model, "gradient_checkpointing_disable"):
|
||||
language_model.gradient_checkpointing_disable()
|
||||
elif hasattr(language_model, "gradient_checkpointing"):
|
||||
language_model.gradient_checkpointing = False
|
||||
if hasattr(language_model, "model"):
|
||||
inner = language_model.model
|
||||
if hasattr(inner, "gradient_checkpointing_disable"):
|
||||
inner.gradient_checkpointing_disable()
|
||||
elif hasattr(inner, "gradient_checkpointing"):
|
||||
inner.gradient_checkpointing = False
|
||||
return
|
||||
|
||||
def _enable_ckpt(module: nn.Module | None) -> bool:
|
||||
if module is None:
|
||||
return False
|
||||
if hasattr(module, "gradient_checkpointing_enable"):
|
||||
try:
|
||||
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=checkpoint_kwargs)
|
||||
except TypeError:
|
||||
module.gradient_checkpointing_enable()
|
||||
return True
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = True
|
||||
return True
|
||||
return False
|
||||
|
||||
enabled_any = _enable_ckpt(self.model)
|
||||
|
||||
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||
encoder = self.model.vision_model.encoder
|
||||
encoder.gradient_checkpointing = True
|
||||
_patch_vision_encoder_checkpointing(
|
||||
encoder, use_reentrant=self.gradient_checkpointing_use_reentrant
|
||||
)
|
||||
enabled_any = True
|
||||
|
||||
language_model = getattr(self.model, "language_model", None)
|
||||
if language_model is not None:
|
||||
enabled_any = _enable_ckpt(language_model) or enabled_any
|
||||
if hasattr(language_model, "model"):
|
||||
enabled_any = _enable_ckpt(language_model.model) or enabled_any
|
||||
if hasattr(language_model, "config"):
|
||||
language_model.config.use_cache = False
|
||||
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
if hasattr(self.model, "enable_input_require_grads"):
|
||||
self.model.enable_input_require_grads()
|
||||
|
||||
if enabled_any:
|
||||
logger.info("Gradient checkpointing enabled for InternVL3 embedder.")
|
||||
else:
|
||||
logger.warning(
|
||||
"Requested gradient checkpointing, but model does not expose checkpointing controls."
|
||||
)
|
||||
|
||||
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(image, torch.Tensor):
|
||||
pil_image = to_pil_image(image.detach().cpu())
|
||||
else:
|
||||
pil_image = image.convert("RGB")
|
||||
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
|
||||
tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to(
|
||||
device=self.device, dtype=torch.bfloat16
|
||||
)
|
||||
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||
return (tile_tensors - mean) / std
|
||||
|
||||
def _preprocess_images(
|
||||
self,
|
||||
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
) -> tuple[torch.Tensor, list[list[int]]]:
|
||||
pixel_values_list = []
|
||||
batch_num_tiles_list: list[list[int]] = []
|
||||
|
||||
for image_tensors in image_tensors_batch:
|
||||
num_tiles_list: list[int] = []
|
||||
for image in image_tensors:
|
||||
tiles = self._preprocess_single_image(image)
|
||||
pixel_values_list.append(tiles)
|
||||
num_tiles_list.append(int(tiles.shape[0]))
|
||||
batch_num_tiles_list.append(num_tiles_list)
|
||||
|
||||
if pixel_values_list:
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
else:
|
||||
pixel_values = torch.empty(
|
||||
0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
return pixel_values, batch_num_tiles_list
|
||||
|
||||
def _build_multimodal_prompts(
|
||||
self,
|
||||
batch_num_tiles_list: list[list[int]],
|
||||
text_prompts: Sequence[str],
|
||||
) -> list[str]:
|
||||
prompts = []
|
||||
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
|
||||
prompt_segments = []
|
||||
for i, tile_count in enumerate(num_tiles_list):
|
||||
token_count = self.model.num_image_token * tile_count
|
||||
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
|
||||
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
|
||||
prompts.append("".join(prompt_segments) + text_prompt.strip())
|
||||
return prompts
|
||||
|
||||
def _prepare_and_fuse_embeddings(
|
||||
self,
|
||||
prompts: Sequence[str],
|
||||
vit_embeds: torch.Tensor,
|
||||
image_masks: torch.Tensor,
|
||||
batch_num_tiles_list: list[list[int]],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
|
||||
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
|
||||
if true_sequence_length > self.max_text_length:
|
||||
logger.warning(
|
||||
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
|
||||
self.max_text_length,
|
||||
true_sequence_length,
|
||||
)
|
||||
|
||||
model_inputs = self.tokenizer(
|
||||
list(prompts),
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.max_text_length,
|
||||
).to(self.device)
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs["attention_mask"]
|
||||
|
||||
img_token_mask = input_ids == self.img_context_token_id
|
||||
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
|
||||
|
||||
batch_size, _, channels = input_embeds.shape
|
||||
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
|
||||
tokens_per_tile = self.model.num_image_token
|
||||
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
|
||||
|
||||
vit_idx = 0
|
||||
for batch_index in range(batch_size):
|
||||
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
|
||||
mask_b = img_token_mask[batch_index]
|
||||
actual_vis_tokens = actual_vis_tokens_list[batch_index]
|
||||
|
||||
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
|
||||
vit_idx += expected_vis_tokens
|
||||
if actual_vis_tokens > 0:
|
||||
if item_vit_embeds.shape[0] < actual_vis_tokens:
|
||||
raise ValueError(
|
||||
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
|
||||
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
|
||||
)
|
||||
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
|
||||
|
||||
current_token_idx = 0
|
||||
img_token_locations = torch.where(mask_b)[0]
|
||||
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
|
||||
num_tokens_for_image = num_tiles * tokens_per_tile
|
||||
if not bool(image_masks[batch_index, image_index].item()):
|
||||
start_offset = current_token_idx
|
||||
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
|
||||
if start_offset < end_offset:
|
||||
idxs = img_token_locations[start_offset:end_offset]
|
||||
attention_mask[batch_index, idxs] = 0
|
||||
current_token_idx += num_tokens_for_image
|
||||
|
||||
return input_embeds, attention_mask
|
||||
|
||||
def get_fused_image_text_embedding_from_tensor_images(
|
||||
self,
|
||||
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
image_masks: torch.Tensor,
|
||||
text_prompts: Sequence[str],
|
||||
return_cls_only: bool = True,
|
||||
):
|
||||
pixel_values, batch_num_tiles_list = self._preprocess_images(image_tensors_batch)
|
||||
if pixel_values.shape[0] == 0:
|
||||
logger.warning("InternVL3 received an empty image batch after preprocessing.")
|
||||
hidden_size = getattr(self.model.config, "hidden_size", None)
|
||||
if hidden_size is None and hasattr(self.model.language_model, "config"):
|
||||
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
|
||||
if hidden_size is None:
|
||||
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
|
||||
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
|
||||
return empty
|
||||
|
||||
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
|
||||
vit_embeds = self.model.extract_feature(pixel_values)
|
||||
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
|
||||
prompts,
|
||||
vit_embeds,
|
||||
image_masks.to(device=self.device),
|
||||
batch_num_tiles_list,
|
||||
)
|
||||
|
||||
outputs = self.model.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
return_dict=True,
|
||||
)
|
||||
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
|
||||
return fused_hidden[:, 0, :] if return_cls_only else fused_hidden
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.model.parameters()).device
|
||||
@@ -0,0 +1,450 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||
from lerobot.policies.evo1.evo1_model import EVO1
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class EVO1Policy(PreTrainedPolicy):
|
||||
config_class = Evo1Config
|
||||
name = "evo1"
|
||||
|
||||
def __init__(self, config: Evo1Config, **kwargs):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
|
||||
if len(config.image_features) > config.max_views:
|
||||
raise ValueError(
|
||||
f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.model = EVO1(self._build_model_config(config))
|
||||
self.model.set_finetune_flags()
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
if strict is None:
|
||||
strict = not (config is not None and getattr(config, "training_stage", None) == "stage2")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
strict=strict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_model_config(config: Evo1Config) -> dict:
|
||||
return {
|
||||
"device": config.device,
|
||||
"return_cls_only": config.return_cls_only,
|
||||
"vlm_name": config.vlm_model_name,
|
||||
"vlm_num_layers": config.vlm_num_layers,
|
||||
"vlm_dtype": config.vlm_dtype,
|
||||
"use_flash_attn": config.use_flash_attn,
|
||||
"action_head": config.action_head,
|
||||
"action_horizon": config.chunk_size,
|
||||
"per_action_dim": config.max_action_dim,
|
||||
"state_dim": config.max_state_dim,
|
||||
"embed_dim": config.embed_dim,
|
||||
"hidden_dim": config.hidden_dim,
|
||||
"state_hidden_dim": config.state_hidden_dim,
|
||||
"num_heads": config.num_heads,
|
||||
"num_layers": config.num_layers,
|
||||
"dropout": config.dropout,
|
||||
"num_inference_timesteps": config.num_inference_timesteps,
|
||||
"num_categories": config.num_categories,
|
||||
"enable_gradient_checkpointing": config.enable_gradient_checkpointing,
|
||||
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
|
||||
"finetune_vlm": config.finetune_vlm,
|
||||
"finetune_language_model": config.finetune_language_model,
|
||||
"finetune_vision_model": config.finetune_vision_model,
|
||||
"finetune_action_head": config.finetune_action_head,
|
||||
}
|
||||
|
||||
@property
|
||||
def _camera_keys(self) -> list[str]:
|
||||
return list(self.config.image_features)
|
||||
|
||||
@property
|
||||
def _env_action_dim(self) -> int:
|
||||
action_feature = self.config.action_feature
|
||||
if action_feature is None:
|
||||
return self.config.max_action_dim
|
||||
return int(action_feature.shape[0])
|
||||
|
||||
@property
|
||||
def _compute_dtype(self) -> torch.dtype:
|
||||
return next(self.model.action_head.parameters()).dtype
|
||||
|
||||
@property
|
||||
def _training_compute_dtype(self) -> torch.dtype:
|
||||
if str(self.config.device).startswith("cuda"):
|
||||
return torch.bfloat16
|
||||
return self._compute_dtype
|
||||
|
||||
@property
|
||||
def _inference_compute_dtype(self) -> torch.dtype:
|
||||
if str(self.config.device).startswith("cuda") and self.config.use_amp:
|
||||
return torch.bfloat16
|
||||
return self._compute_dtype
|
||||
|
||||
def get_optim_params(self) -> list[dict]:
|
||||
decay, no_decay = [], []
|
||||
for name, param in self.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
is_bias = name.endswith("bias") or ".bias" in name
|
||||
is_norm = param.dim() == 1 or "norm" in name.lower()
|
||||
if is_bias or is_norm:
|
||||
no_decay.append(param)
|
||||
else:
|
||||
decay.append(param)
|
||||
return [
|
||||
{"params": decay, "weight_decay": self.config.optimizer_weight_decay},
|
||||
{"params": no_decay, "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
|
||||
prompts = batch.get(self.config.task_field)
|
||||
if prompts is None and self.config.task_field != "task":
|
||||
prompts = batch.get("task")
|
||||
if prompts is None:
|
||||
raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.")
|
||||
if isinstance(prompts, str):
|
||||
return [prompts]
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [str(prompt) for prompt in prompts]
|
||||
raise TypeError(f"Unsupported prompt batch type: {type(prompts)}")
|
||||
|
||||
def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||
if OBS_STATE not in batch:
|
||||
raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.")
|
||||
state = batch[OBS_STATE]
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
elif state.dim() == 3:
|
||||
state = state[:, -1]
|
||||
elif state.dim() != 2:
|
||||
raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}")
|
||||
batch_size, state_dim = state.shape
|
||||
if state_dim > self.config.max_state_dim:
|
||||
raise ValueError(
|
||||
f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}"
|
||||
)
|
||||
explicit_mask = batch.get("state_mask")
|
||||
if explicit_mask is not None:
|
||||
if explicit_mask.dim() == 1:
|
||||
explicit_mask = explicit_mask.unsqueeze(0)
|
||||
elif explicit_mask.dim() == 3:
|
||||
explicit_mask = explicit_mask[:, -1]
|
||||
elif explicit_mask.dim() != 2:
|
||||
raise ValueError(
|
||||
f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
|
||||
)
|
||||
if explicit_mask.shape != (batch_size, state_dim):
|
||||
raise ValueError(
|
||||
f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}"
|
||||
)
|
||||
padded = torch.zeros(
|
||||
batch_size,
|
||||
self.config.max_state_dim,
|
||||
dtype=state.dtype,
|
||||
device=self.config.device,
|
||||
)
|
||||
padded[:, :state_dim] = state.to(device=self.config.device)
|
||||
mask = torch.zeros(
|
||||
batch_size,
|
||||
self.config.max_state_dim,
|
||||
dtype=torch.bool,
|
||||
device=self.config.device,
|
||||
)
|
||||
if explicit_mask is None:
|
||||
mask[:, :state_dim] = True
|
||||
else:
|
||||
mask[:, :state_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
|
||||
return padded.to(dtype=self._compute_dtype), mask
|
||||
|
||||
def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||
if ACTION not in batch:
|
||||
raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.")
|
||||
action = batch[ACTION]
|
||||
if action.dim() == 2:
|
||||
action = action.unsqueeze(1)
|
||||
batch_size, horizon, action_dim = action.shape
|
||||
if horizon != self.config.chunk_size:
|
||||
raise ValueError(
|
||||
f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}"
|
||||
)
|
||||
if action_dim > self.config.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}"
|
||||
)
|
||||
explicit_mask = batch.get("action_mask")
|
||||
if explicit_mask is not None:
|
||||
if explicit_mask.dim() == 2:
|
||||
if horizon == 1:
|
||||
explicit_mask = explicit_mask.unsqueeze(1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}"
|
||||
)
|
||||
elif explicit_mask.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
|
||||
)
|
||||
if explicit_mask.shape != (batch_size, horizon, action_dim):
|
||||
raise ValueError(
|
||||
"action_mask shape "
|
||||
f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}"
|
||||
)
|
||||
padded = torch.zeros(
|
||||
batch_size,
|
||||
horizon,
|
||||
self.config.max_action_dim,
|
||||
dtype=action.dtype,
|
||||
device=self.config.device,
|
||||
)
|
||||
padded[:, :, :action_dim] = action.to(device=self.config.device)
|
||||
mask = torch.zeros(
|
||||
batch_size,
|
||||
horizon,
|
||||
self.config.max_action_dim,
|
||||
dtype=torch.bool,
|
||||
device=self.config.device,
|
||||
)
|
||||
if explicit_mask is None:
|
||||
mask[:, :, :action_dim] = True
|
||||
else:
|
||||
mask[:, :, :action_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
|
||||
return padded.to(dtype=self._compute_dtype), mask
|
||||
|
||||
def _prepare_inference_action_mask(self, batch_size: int) -> Tensor:
|
||||
mask = torch.zeros(
|
||||
batch_size,
|
||||
self.config.max_action_dim,
|
||||
dtype=torch.bool,
|
||||
device=self.config.device,
|
||||
)
|
||||
mask[:, : self._env_action_dim] = True
|
||||
return mask
|
||||
|
||||
def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor:
|
||||
embodiment_ids = batch.get("embodiment_id")
|
||||
if embodiment_ids is None and self.config.embodiment_id_field:
|
||||
embodiment_ids = batch.get(self.config.embodiment_id_field)
|
||||
if embodiment_ids is None:
|
||||
return torch.full(
|
||||
(batch_size,),
|
||||
self.config.default_embodiment_id,
|
||||
dtype=torch.long,
|
||||
device=self.config.device,
|
||||
)
|
||||
if embodiment_ids.dim() == 0:
|
||||
embodiment_ids = embodiment_ids.unsqueeze(0)
|
||||
elif embodiment_ids.dim() > 1:
|
||||
embodiment_ids = embodiment_ids[:, -1]
|
||||
return embodiment_ids.to(device=self.config.device, dtype=torch.long)
|
||||
|
||||
@property
|
||||
def _tracks_vlm_gradients(self) -> bool:
|
||||
return bool(
|
||||
self.config.finetune_vlm
|
||||
or self.config.finetune_language_model
|
||||
or self.config.finetune_vision_model
|
||||
)
|
||||
|
||||
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
|
||||
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
|
||||
if not camera_keys:
|
||||
raise ValueError("EVO1 requires at least one visual observation feature.")
|
||||
|
||||
# Normalize each camera tensor to (B, C, H, W) up-front so that batch_size is read
|
||||
# from a real batch dim and not from C in the unbatched (C, H, W) case.
|
||||
normalized: dict[str, Tensor] = {}
|
||||
for camera_key in camera_keys[: self.config.max_views]:
|
||||
image = batch[camera_key]
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
elif image.dim() == 5:
|
||||
image = image[:, -1]
|
||||
elif image.dim() != 4:
|
||||
raise ValueError(
|
||||
f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
|
||||
)
|
||||
normalized[camera_key] = image
|
||||
|
||||
batch_size = normalized[camera_keys[0]].shape[0]
|
||||
image_batches: list[list[Tensor]] = []
|
||||
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
|
||||
|
||||
for batch_index in range(batch_size):
|
||||
sample_images: list[Tensor] = []
|
||||
for camera_key in camera_keys[: self.config.max_views]:
|
||||
sample_images.append(normalized[camera_key][batch_index].detach().cpu())
|
||||
if not sample_images:
|
||||
raise ValueError("EVO1 received a batch without any image tensor.")
|
||||
while len(sample_images) < self.config.max_views:
|
||||
sample_images.append(torch.zeros_like(sample_images[0]))
|
||||
image_batches.append(sample_images[: self.config.max_views])
|
||||
image_masks[batch_index, : min(len(camera_keys), self.config.max_views)] = True
|
||||
|
||||
return image_batches, image_masks
|
||||
|
||||
def _compute_fused_tokens(
|
||||
self,
|
||||
prompts: list[str],
|
||||
image_batches: list[list[Tensor]],
|
||||
image_masks: Tensor,
|
||||
) -> Tensor:
|
||||
track_vlm_gradients = self._tracks_vlm_gradients
|
||||
grad_context = nullcontext() if track_vlm_gradients else torch.no_grad()
|
||||
embedder = getattr(self.model, "embedder", None)
|
||||
embedder_was_training = embedder.training if embedder is not None else None
|
||||
|
||||
if not track_vlm_gradients and embedder is not None:
|
||||
embedder.eval()
|
||||
|
||||
try:
|
||||
with grad_context:
|
||||
fused_tokens = self.model.get_vl_embeddings(
|
||||
images=image_batches,
|
||||
image_mask=image_masks,
|
||||
prompt=prompts,
|
||||
return_cls_only=self.config.return_cls_only,
|
||||
)
|
||||
finally:
|
||||
if not track_vlm_gradients and embedder is not None and embedder_was_training is not None:
|
||||
embedder.train(embedder_was_training)
|
||||
|
||||
if not track_vlm_gradients:
|
||||
fused_tokens = fused_tokens.detach()
|
||||
return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype)
|
||||
|
||||
def _compute_masked_loss(
|
||||
self,
|
||||
pred_velocity: Tensor,
|
||||
target_velocity: Tensor,
|
||||
action_mask: Tensor,
|
||||
reduction: str,
|
||||
) -> Tensor:
|
||||
flat_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype)
|
||||
sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2)
|
||||
active = flat_mask.sum(dim=1).clamp_min(1.0)
|
||||
per_sample_loss = sq_error.sum(dim=1) / active
|
||||
if reduction == "none":
|
||||
return per_sample_loss
|
||||
if reduction != "mean":
|
||||
raise ValueError(f"Unsupported reduction '{reduction}'")
|
||||
return sq_error.sum() / active.sum()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||
prompts = self._normalize_task_batch(batch)
|
||||
image_batches, image_masks = self._collect_image_batches(batch)
|
||||
states, _state_mask = self._prepare_state(batch)
|
||||
actions_gt, action_mask = self._prepare_actions(batch)
|
||||
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
|
||||
states = states.to(dtype=self._training_compute_dtype)
|
||||
actions_gt = actions_gt.to(dtype=self._training_compute_dtype)
|
||||
fused_tokens = fused_tokens.to(dtype=self._training_compute_dtype)
|
||||
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
|
||||
|
||||
pred_velocity, noise = self.model(
|
||||
fused_tokens,
|
||||
state=states,
|
||||
actions_gt=actions_gt,
|
||||
action_mask=action_mask.to(device=self.config.device, dtype=self._compute_dtype),
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=actions_gt.dtype)
|
||||
target_velocity = (actions_gt - noise).view(actions_gt.shape[0], -1) * flat_action_mask
|
||||
loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
|
||||
loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
|
||||
return loss, {
|
||||
"loss": loss_mean,
|
||||
"active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
prompts = self._normalize_task_batch(batch)
|
||||
image_batches, image_masks = self._collect_image_batches(batch)
|
||||
states, _state_mask = self._prepare_state(batch)
|
||||
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
|
||||
states = states.to(dtype=self._inference_compute_dtype)
|
||||
fused_tokens = fused_tokens.to(dtype=self._inference_compute_dtype)
|
||||
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
|
||||
action_mask = self._prepare_inference_action_mask(states.shape[0])
|
||||
|
||||
with (
|
||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
if self.config.use_amp and str(self.config.device).startswith("cuda")
|
||||
else nullcontext()
|
||||
):
|
||||
actions = self.model(
|
||||
fused_tokens,
|
||||
state=states,
|
||||
action_mask=action_mask,
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
|
||||
return actions[:, :, : self._env_action_dim]
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
if len(self._action_queue) == 0:
|
||||
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(action_chunk.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
@@ -0,0 +1,106 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
DONE,
|
||||
INFO,
|
||||
OBS_PREFIX,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
REWARD,
|
||||
TRUNCATED,
|
||||
)
|
||||
|
||||
|
||||
def evo1_batch_to_transition(batch: dict[str, Any]):
|
||||
transition = batch_to_transition(batch)
|
||||
complementary_data = dict(transition.get("complementary_data") or {})
|
||||
reserved = {ACTION, REWARD, DONE, TRUNCATED, INFO}
|
||||
for key, value in batch.items():
|
||||
if key in reserved or key.startswith(OBS_PREFIX):
|
||||
continue
|
||||
complementary_data.setdefault(key, value)
|
||||
return create_transition(
|
||||
observation=transition.get("observation"),
|
||||
action=transition.get("action"),
|
||||
reward=transition.get("reward", 0.0),
|
||||
done=transition.get("done", False),
|
||||
truncated=transition.get("truncated", False),
|
||||
info=transition.get("info", {}),
|
||||
complementary_data=complementary_data,
|
||||
)
|
||||
|
||||
|
||||
def make_evo1_pre_post_processors(
|
||||
config: Evo1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=evo1_batch_to_transition,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -46,14 +46,14 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .evo1.configuration_evo1 import Evo1Config
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
from .pretrained import PreTrainedPolicy
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .sac.configuration_sac import SACConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
@@ -89,7 +89,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "reward_classifier", "smolvla", "wall_x".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x", "eo1", "evo1".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -128,22 +128,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
elif name == "sac":
|
||||
from .sac.modeling_sac import SACPolicy
|
||||
|
||||
return GaussianActorPolicy
|
||||
elif name == "reward_classifier":
|
||||
from .gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
return SACPolicy
|
||||
elif name == "smolvla":
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "sarm":
|
||||
from .sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "groot":
|
||||
from .groot.modeling_groot import GrootPolicy
|
||||
|
||||
@@ -156,6 +148,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .wall_x.modeling_wall_x import WallXPolicy
|
||||
|
||||
return WallXPolicy
|
||||
elif name == "eo1":
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
elif name == "evo1":
|
||||
from .evo1.modeling_evo1 import EVO1Policy
|
||||
|
||||
return EVO1Policy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -172,8 +172,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "reward_classifier", "wall_x".
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"smolvla", "wall_x", "eo1", "evo1".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -196,18 +196,20 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
elif policy_type == "xvla":
|
||||
return XVLAConfig(**kwargs)
|
||||
elif policy_type == "wall_x":
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "evo1":
|
||||
return Evo1Config(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -370,18 +372,10 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, GaussianActorConfig):
|
||||
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from .sac.processor_sac import make_sac_pre_post_processors
|
||||
|
||||
processors = make_gaussian_actor_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
||||
from .gaussian_actor.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
processors = make_classifier_processor(
|
||||
processors = make_sac_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
@@ -394,14 +388,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SARMConfig):
|
||||
from .sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
processors = make_sarm_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from .groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
@@ -427,6 +413,20 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
processors = make_eo1_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, Evo1Config):
|
||||
from .evo1.processor_evo1 import make_evo1_pre_post_processors
|
||||
|
||||
processors = make_evo1_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
@@ -542,7 +542,7 @@ def make_policy(
|
||||
|
||||
logging.info("Loading policy's PEFT adapter.")
|
||||
|
||||
peft_pretrained_path = cfg.pretrained_path
|
||||
peft_pretrained_path = str(cfg.pretrained_path)
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
||||
@@ -555,7 +555,9 @@ def make_policy(
|
||||
)
|
||||
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
policy = PeftModel.from_pretrained(
|
||||
policy, peft_pretrained_path, config=peft_config, is_trainable=True
|
||||
)
|
||||
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -109,7 +109,6 @@ class MultiEmbodimentActionEncoder(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
|
||||
@@ -444,13 +444,13 @@ class PaliGemmaWithExpertModel(
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -666,8 +666,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -748,16 +747,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
@@ -1292,8 +1283,11 @@ class PI0Policy(PreTrainedPolicy):
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -728,14 +728,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
@@ -1262,8 +1256,11 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise, time)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
@@ -261,13 +260,15 @@ class PI0FastPaliGemma(nn.Module):
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
features = image_outputs.pooler_output
|
||||
norm = 2048**0.5
|
||||
features = features / norm * norm
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -417,8 +418,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language instruction tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -432,8 +432,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
def fast_action_embed_func(fast_action_tokens):
|
||||
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||
fast_emb_dim = fast_emb.shape[-1]
|
||||
return fast_emb * math.sqrt(fast_emb_dim)
|
||||
return fast_emb
|
||||
|
||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||
embs.append(fast_action_emb)
|
||||
@@ -666,7 +665,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if t < max_decoding_steps - 1:
|
||||
# embed the newly generated token
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
@@ -771,7 +769,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Embed the single previous token
|
||||
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
+4
-4
@@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
from .modeling_gaussian_actor import GaussianActorPolicy
|
||||
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
from .configuration_sac import SACConfig
|
||||
from .modeling_sac import SACPolicy
|
||||
from .processor_sac import make_sac_pre_post_processors
|
||||
|
||||
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]
|
||||
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
|
||||
+66
-30
@@ -75,19 +75,18 @@ class PolicyConfig:
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("gaussian_actor")
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class GaussianActorConfig(PreTrainedConfig):
|
||||
"""Gaussian actor configuration.
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
|
||||
This configures the policy-side (actor + observation encoder) of a Gaussian
|
||||
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
|
||||
By default the actor output is a tanh-squashed diagonal Gaussian
|
||||
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
|
||||
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
|
||||
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
|
||||
CLI: ``--policy.type=gaussian_actor``.
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
@@ -123,7 +122,7 @@ class GaussianActorConfig(PreTrainedConfig):
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
@@ -136,41 +135,78 @@ class GaussianActorConfig(PreTrainedConfig):
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Encoder architecture
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
offline_buffer_capacity: int = 100000
|
||||
# Whether to use asynchronous prefetching for the buffers
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
|
||||
online_steps: int = 1000000
|
||||
online_buffer_capacity: int = 100000
|
||||
offline_buffer_capacity: int = 100000
|
||||
async_prefetch: bool = False
|
||||
online_step_before_learning: int = 100
|
||||
|
||||
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Network architecture
|
||||
# Actor network
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Gaussian head parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Discrete critic
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": 3e-4},
|
||||
"critic": {"lr": 3e-4},
|
||||
"temperature": {"lr": 3e-4},
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
},
|
||||
)
|
||||
|
||||
+443
-57
@@ -15,12 +15,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
@@ -28,20 +32,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
|
||||
from .configuration_sac import SACConfig, is_image_feature
|
||||
|
||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
|
||||
|
||||
class GaussianActorPolicy(
|
||||
class SACPolicy(
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
config_class = GaussianActorConfig
|
||||
name = "gaussian_actor"
|
||||
config_class = SACConfig
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GaussianActorConfig | None = None,
|
||||
config: SACConfig | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -50,8 +54,9 @@ class GaussianActorPolicy(
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_discrete_critic()
|
||||
self._init_temperature()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
@@ -60,7 +65,11 @@ class GaussianActorPolicy(
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["discrete_critic"] = self.discrete_critic.parameters()
|
||||
return optim_params
|
||||
|
||||
def reset(self):
|
||||
@@ -70,9 +79,7 @@ class GaussianActorPolicy(
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError(
|
||||
"GaussianActorPolicy does not support action chunking. It returns single actions!"
|
||||
)
|
||||
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -85,55 +92,360 @@ class GaussianActorPolicy(
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
if self.discrete_critic is not None:
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
else:
|
||||
discrete_action = torch.ones(
|
||||
(*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype
|
||||
)
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]:
|
||||
"""Actor forward pass: sample actions and return log-probabilities.
|
||||
def critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
batch: A flat observation dict, or a training dict containing
|
||||
``"state"`` (observations) and optionally ``"observation_feature"``
|
||||
(pre-computed encoder features).
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors.
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
observations = batch.get("state", batch)
|
||||
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
|
||||
actions, log_probs, means = self.actor(observations, observation_features)
|
||||
return {"action": actions, "log_prob": log_probs, "action_mean": means}
|
||||
|
||||
def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
from lerobot.utils.transition import move_state_dict_to_device
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
self.actor.load_state_dict(actor_state_dict)
|
||||
def discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
if "discrete_critic" in state_dicts and self.discrete_critic is not None:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Compute the loss for the given model
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing:
|
||||
- action: Action tensor
|
||||
- reward: Reward tensor
|
||||
- state: Observations tensor dict
|
||||
- next_state: Next observations tensor dict
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
|
||||
loss_critic = self.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
self.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
loss_discrete_critic = self.compute_loss_discrete_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
return {"loss_discrete_critic": loss_discrete_critic}
|
||||
if model == "actor":
|
||||
return {
|
||||
"loss_actor": self.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
if model == "temperature":
|
||||
return {
|
||||
"loss_temperature": self.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_param, param in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_discrete_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features=None,
|
||||
next_observation_features=None,
|
||||
complementary_info=None,
|
||||
):
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self.discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self.discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self.discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self,
|
||||
observations,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = GaussianActorObservationEncoder(self.config)
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critics()
|
||||
|
||||
def _init_discrete_critics(self):
|
||||
"""Build discrete discrete critic ensemble and target networks."""
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network."""
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
# NOTE: The actor select only the continuous action part
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
@@ -143,25 +455,21 @@ class GaussianActorPolicy(
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
|
||||
def _init_discrete_critic(self) -> None:
|
||||
"""Initialize discrete critic network."""
|
||||
if self.config.num_discrete_actions is None:
|
||||
self.discrete_critic = None
|
||||
return
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
# TODO(Khalil): Compile the discrete critic
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
|
||||
|
||||
class GaussianActorObservationEncoder(nn.Module):
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: GaussianActorConfig) -> None:
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._init_image_layers()
|
||||
@@ -369,6 +677,84 @@ class MLP(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: float | None = None,
|
||||
init_final: float | None = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
||||
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
|
||||
|
||||
class DiscreteCritic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -414,7 +800,7 @@ class DiscreteCritic(nn.Module):
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
encoder: SACObservationEncoder,
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
std_min: float = -5,
|
||||
@@ -425,7 +811,7 @@ class Policy(nn.Module):
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder: GaussianActorObservationEncoder = encoder
|
||||
self.encoder: SACObservationEncoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.std_min = std_min
|
||||
@@ -499,7 +885,7 @@ class Policy(nn.Module):
|
||||
|
||||
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
def __init__(self, config: SACConfig):
|
||||
super().__init__()
|
||||
image_key = next(key for key in config.input_features if is_image_feature(key))
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
@@ -545,12 +931,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
def __init__(self, config: SACConfig):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
+5
-5
@@ -32,18 +32,18 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
from .configuration_sac import SACConfig
|
||||
|
||||
|
||||
def make_gaussian_actor_pre_post_processors(
|
||||
config: GaussianActorConfig,
|
||||
def make_sac_pre_post_processors(
|
||||
config: SACConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
|
||||
Constructs pre-processor and post-processor pipelines for the SAC policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
@@ -56,7 +56,7 @@ def make_gaussian_actor_pre_post_processors(
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the tanh-Gaussian policy.
|
||||
config: The configuration object for the SAC policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
|
||||
Returns:
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/policy_sarm_README.md
|
||||
@@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
use_group_norm: bool = False
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
# VQ-VAE
|
||||
n_vqvae_training_steps: int = 20000
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
|
||||
@@ -45,7 +45,7 @@ from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_flash_attn_greater_or_equal,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
||||
|
||||
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||
@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
|
||||
@@ -40,7 +40,7 @@ from .converters import (
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
|
||||
from .env_processor import IsaaclabArenaProcessorStep, LiberoActionProcessorStep, LiberoProcessorStep
|
||||
from .factory import (
|
||||
make_default_processors,
|
||||
make_default_robot_action_processor,
|
||||
@@ -61,7 +61,6 @@ from .hil_processor import (
|
||||
RewardClassifierProcessorStep,
|
||||
TimeLimitProcessorStep,
|
||||
)
|
||||
from .leader_follower_processor import LeaderFollowerProcessor
|
||||
from .newline_task_processor import NewLineTaskProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessorStep
|
||||
@@ -123,7 +122,6 @@ __all__ = [
|
||||
"ImageCropResizeProcessorStep",
|
||||
"InfoProcessorStep",
|
||||
"InterventionActionProcessorStep",
|
||||
"LeaderFollowerProcessor",
|
||||
"make_default_processors",
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
@@ -151,6 +149,7 @@ __all__ = [
|
||||
"RewardProcessorStep",
|
||||
"DataProcessorPipeline",
|
||||
"IsaaclabArenaProcessorStep",
|
||||
"LiberoActionProcessorStep",
|
||||
"LiberoProcessorStep",
|
||||
"TimeLimitProcessorStep",
|
||||
"AddBatchDimensionProcessorStep",
|
||||
|
||||
@@ -38,7 +38,6 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
"""
|
||||
|
||||
use_gripper: bool = True
|
||||
use_rotation: bool = False
|
||||
|
||||
def action(self, action: PolicyAction) -> RobotAction:
|
||||
if not isinstance(action, PolicyAction):
|
||||
@@ -53,13 +52,7 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
"delta_y": action[1].item(),
|
||||
"delta_z": action[2].item(),
|
||||
}
|
||||
if self.use_rotation:
|
||||
delta_action["delta_wx"] = action[3].item()
|
||||
delta_action["delta_wy"] = action[4].item()
|
||||
delta_action["delta_wz"] = action[5].item()
|
||||
if self.use_gripper:
|
||||
delta_action["gripper"] = action[6].item()
|
||||
elif self.use_gripper:
|
||||
if self.use_gripper:
|
||||
delta_action["gripper"] = action[3].item()
|
||||
return delta_action
|
||||
|
||||
@@ -71,12 +64,6 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
if self.use_rotation:
|
||||
for axis in ["wx", "wy", "wz"]:
|
||||
features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
if self.use_gripper:
|
||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
@@ -103,8 +90,6 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
# Scale factors for delta movements
|
||||
position_scale: float = 1.0
|
||||
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
|
||||
use_rotation: bool = False
|
||||
rotation_scale: float = 1.0
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
@@ -112,34 +97,23 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
delta_x = action.pop("delta_x")
|
||||
delta_y = action.pop("delta_y")
|
||||
delta_z = action.pop("delta_z")
|
||||
if self.use_rotation:
|
||||
delta_wx = action.pop("delta_wx")
|
||||
delta_wy = action.pop("delta_wy")
|
||||
delta_wz = action.pop("delta_wz")
|
||||
else:
|
||||
delta_wx = 0.0
|
||||
delta_wy = 0.0
|
||||
delta_wz = 0.0
|
||||
gripper = action.pop("gripper")
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position
|
||||
rotation_magnitude = (
|
||||
delta_wx**2 + delta_wy**2 + delta_wz**2
|
||||
) ** 0.5 # TODO use proper magnitud for rotation
|
||||
enabled = (
|
||||
position_magnitude > self.noise_threshold or rotation_magnitude > self.noise_threshold
|
||||
) # Small threshold to avoid noise
|
||||
enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise
|
||||
|
||||
# Scale the deltas appropriately
|
||||
scaled_delta_x = delta_x * self.position_scale
|
||||
scaled_delta_y = delta_y * self.position_scale
|
||||
scaled_delta_z = delta_z * self.position_scale
|
||||
|
||||
target_wx = delta_wx * self.rotation_scale
|
||||
target_wy = delta_wy * self.rotation_scale
|
||||
target_wz = delta_wz * self.rotation_scale
|
||||
# For gamepad/keyboard, we don't have rotation input, so set to 0
|
||||
# These could be extended in the future for more sophisticated teleoperators
|
||||
target_wx = 0.0
|
||||
target_wy = 0.0
|
||||
target_wz = 0.0
|
||||
|
||||
# Update action with robot target format
|
||||
action = {
|
||||
@@ -158,15 +132,9 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
for axis in ["x", "y", "z"]:
|
||||
for axis in ["x", "y", "z", "gripper"]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None)
|
||||
|
||||
if self.use_rotation:
|
||||
for axis in ["wx", "wy", "wz"]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None)
|
||||
|
||||
features[PipelineFeatureType.ACTION].pop("delta_gripper", None)
|
||||
|
||||
for feat in ["enabled", "target_x", "target_y", "target_z", "target_wx", "target_wy", "target_wz"]:
|
||||
features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
|
||||
@@ -18,9 +18,9 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -46,6 +46,8 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
|
||||
"""
|
||||
|
||||
max_state_dim: int | None = None
|
||||
|
||||
def _process_observation(self, observation):
|
||||
"""
|
||||
Processes both image and robot_state observations from LIBERO.
|
||||
@@ -78,6 +80,15 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
state = state.float()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
if self.max_state_dim is not None:
|
||||
if state.shape[-1] > self.max_state_dim:
|
||||
raise ValueError(
|
||||
f"LIBERO state has {state.shape[-1]} dims, which is larger than "
|
||||
f"configured max_state_dim={self.max_state_dim}."
|
||||
)
|
||||
if state.shape[-1] < self.max_state_dim:
|
||||
pad_width = self.max_state_dim - state.shape[-1]
|
||||
state = torch.nn.functional.pad(state, (0, pad_width))
|
||||
|
||||
processed_obs[OBS_STATE] = state
|
||||
return processed_obs
|
||||
@@ -101,7 +112,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
# add our new flattened state
|
||||
state_feats[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
|
||||
shape=(self.max_state_dim or 8,), # [eef_pos(3), axis_angle(3), gripper(2)] plus padding
|
||||
)
|
||||
|
||||
new_features[FeatureType.STATE] = state_feats
|
||||
@@ -111,6 +122,9 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def get_config(self) -> dict:
|
||||
return {"max_state_dim": self.max_state_dim}
|
||||
|
||||
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert batched quaternions to axis-angle format.
|
||||
@@ -153,6 +167,32 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="libero_action_processor")
|
||||
class LiberoActionProcessorStep(ActionProcessorStep):
|
||||
"""Slices padded policy actions back to the executable LIBERO action space."""
|
||||
|
||||
action_dim: int = 7
|
||||
|
||||
def action(self, action):
|
||||
if action.shape[-1] < self.action_dim:
|
||||
raise ValueError(
|
||||
f"LIBERO action has {action.shape[-1]} dims, which is smaller than action_dim={self.action_dim}."
|
||||
)
|
||||
return action[..., : self.action_dim]
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
new_features = {ft: feats.copy() for ft, feats in features.items()}
|
||||
action_feats = new_features.setdefault(FeatureType.ACTION, {})
|
||||
action_feats[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
|
||||
return new_features
|
||||
|
||||
def get_config(self) -> dict:
|
||||
return {"action_dim": self.action_dim}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
|
||||
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
|
||||
|
||||
@@ -321,7 +321,6 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
3. Copying `discrete_penalty` from `info` to `complementary_data`.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -331,9 +330,6 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if DISCRETE_PENALTY_KEY in info:
|
||||
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
@@ -352,24 +348,18 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Applies a small per-transition cost on the discrete gripper action.
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
Fires only when the commanded action would actually transition the gripper
|
||||
from one extreme to the other (close-while-open or open-while-closed).
|
||||
This discourages gripper oscillation while leaving "stay" and saturating-further
|
||||
commands unpenalized.
|
||||
This step penalizes actions that attempt to close an already closed gripper or
|
||||
open an already open one, based on position thresholds.
|
||||
|
||||
Attributes:
|
||||
penalty: The negative reward value to apply.
|
||||
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
||||
open_threshold: Normalized state below which the gripper is considered "open".
|
||||
closed_threshold: Normalized state above which the gripper is considered "closed".
|
||||
"""
|
||||
|
||||
penalty: float = -0.02
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
open_threshold: float = 0.1
|
||||
closed_threshold: float = 0.9
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
@@ -401,13 +391,9 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||
|
||||
# Calculate penalty boolean as in original
|
||||
# - currently open AND target is closed -> close transition
|
||||
# - currently closed AND target is open -> open transition
|
||||
is_open = gripper_state_normalized < self.open_threshold
|
||||
is_closed = gripper_state_normalized > self.closed_threshold
|
||||
cmd_close = gripper_action_normalized > self.closed_threshold
|
||||
cmd_open = gripper_action_normalized < self.open_threshold
|
||||
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
||||
)
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
@@ -423,14 +409,11 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the penalty value, max gripper position,
|
||||
and the open/closed thresholds.
|
||||
A dictionary containing the penalty value and max gripper position.
|
||||
"""
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
"open_threshold": self.open_threshold,
|
||||
"closed_threshold": self.closed_threshold,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -461,7 +444,6 @@ class InterventionActionProcessorStep(ProcessorStep):
|
||||
|
||||
use_gripper: bool = False
|
||||
terminate_on_success: bool = True
|
||||
use_rotation: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
@@ -498,14 +480,6 @@ class InterventionActionProcessorStep(ProcessorStep):
|
||||
teleop_action.get("delta_y", 0.0),
|
||||
teleop_action.get("delta_z", 0.0),
|
||||
]
|
||||
if self.use_rotation:
|
||||
action_list.extend(
|
||||
[
|
||||
teleop_action.get("delta_wx", 0.0),
|
||||
teleop_action.get("delta_wy", 0.0),
|
||||
teleop_action.get("delta_wz", 0.0),
|
||||
]
|
||||
)
|
||||
if self.use_gripper:
|
||||
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
|
||||
elif isinstance(teleop_action, np.ndarray):
|
||||
@@ -583,7 +557,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
def __post_init__(self):
|
||||
"""Initializes the reward classifier model after the dataclass is created."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
|
||||
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
|
||||
self.reward_classifier.to(self.device)
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
from .pipeline import ProcessorStep
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("leader_follower_processor")
|
||||
@dataclass
|
||||
class LeaderFollowerProcessor(ProcessorStep):
|
||||
"""
|
||||
Processor for leader-follower teleoperation mode.
|
||||
|
||||
This processor:
|
||||
1. Sends follower positions to leader arm when not intervening
|
||||
2. Computes EE delta actions from leader when intervening
|
||||
3. Handles teleop events from the leader device
|
||||
"""
|
||||
|
||||
leader_device: Teleoperator
|
||||
motor_names: list[str]
|
||||
robot: Robot
|
||||
kinematics: RobotKinematics
|
||||
end_effector_step_sizes: np.ndarray | None = None
|
||||
use_gripper: bool = True
|
||||
# prev_leader_gripper: float | None = None
|
||||
max_gripper_pos: float = 100.0
|
||||
use_ik_solution: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Process transition with leader-follower logic."""
|
||||
# Get current follower position from complementary data
|
||||
# raw_joint_pos = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("raw_joint_positions")
|
||||
raw_joint_pos = transition.get(TransitionKey.OBSERVATION)
|
||||
if raw_joint_pos is not None:
|
||||
# Send follower position to leader (for follow mode)
|
||||
# follower_action = {
|
||||
# f"{motor}.pos": float(raw_joint_pos[motor])
|
||||
# for motor in self.motor_names
|
||||
# }
|
||||
self.leader_device.send_action(raw_joint_pos)
|
||||
|
||||
# Only compute EE action if intervention is active
|
||||
# (AddTeleopEventsAsInfo already added IS_INTERVENTION to info)
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
if info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
# Get leader joint positions from teleop_action
|
||||
# (AddTeleopActionAsComplimentaryData already got the action)
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
teleop_action = complementary.get("teleop_action", {})
|
||||
|
||||
if isinstance(teleop_action, dict) and raw_joint_pos is not None:
|
||||
leader_pos = np.array([teleop_action[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
leader_ee = self.kinematics.forward_kinematics(leader_pos)
|
||||
|
||||
if self.use_ik_solution and "IK_solution" in transition.get(TransitionKey.COMPLEMENTARY_DATA):
|
||||
follower_pos = transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"]
|
||||
else:
|
||||
follower_pos = np.array([raw_joint_pos[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
follower_ee = self.kinematics.forward_kinematics(follower_pos)
|
||||
|
||||
# follower_gripper_pos = raw_joint_pos["gripper.pos"]
|
||||
follower_gripper_pos = follower_pos[-1] # assuming gripper is the last motor
|
||||
|
||||
leader_ee_pos = leader_ee[:3, 3]
|
||||
leader_ee_rvec = Rotation.from_matrix(leader_ee[:3, :3]).as_rotvec()
|
||||
leader_gripper_pos = np.clip(
|
||||
teleop_action["gripper.pos"], -self.max_gripper_pos, self.max_gripper_pos
|
||||
)
|
||||
|
||||
follower_ee_pos = follower_ee[:3, 3]
|
||||
# follower_ee_rvec = Rotation.from_matrix(follower_ee[:3, :3]).as_rotvec()
|
||||
|
||||
delta_pos = leader_ee_pos - follower_ee_pos
|
||||
|
||||
# For rotation: compute relative rotation from follower to leader
|
||||
# R_leader = R_follower * R_delta => R_delta = R_follower^T * R_leader
|
||||
r_delta = follower_ee[:3, :3].T @ leader_ee[:3, :3]
|
||||
delta_rvec = Rotation.from_matrix(r_delta).as_rotvec()
|
||||
|
||||
delta_gripper = leader_gripper_pos - follower_gripper_pos
|
||||
|
||||
desired = np.eye(4, dtype=float)
|
||||
desired[:3, :3] = follower_ee[:3, :3] @ r_delta
|
||||
desired[:3, 3] = follower_ee[:3, 3] + delta_pos
|
||||
|
||||
pos = desired[:3, 3]
|
||||
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
||||
|
||||
assert np.allclose(pos, leader_ee_pos), "Position delta computation error"
|
||||
assert np.allclose(tw, leader_ee_rvec), "Orientation delta computation error"
|
||||
assert np.isclose(follower_gripper_pos + delta_gripper, leader_gripper_pos), (
|
||||
"Gripper delta computation error"
|
||||
)
|
||||
|
||||
# Normalize the action to the range [-1, 1]
|
||||
delta_pos = delta_pos / np.array(
|
||||
[
|
||||
self.end_effector_step_sizes["x"],
|
||||
self.end_effector_step_sizes["y"],
|
||||
self.end_effector_step_sizes["z"],
|
||||
]
|
||||
)
|
||||
delta_rvec = delta_rvec / np.array(
|
||||
[
|
||||
self.end_effector_step_sizes["wx"],
|
||||
self.end_effector_step_sizes["wy"],
|
||||
self.end_effector_step_sizes["wz"],
|
||||
]
|
||||
)
|
||||
max_normalized_pos = max(
|
||||
abs(delta_pos[0]),
|
||||
abs(delta_pos[1]),
|
||||
abs(delta_pos[2]),
|
||||
)
|
||||
|
||||
normalized_rot = max(abs(delta_rvec[0]), abs(delta_rvec[1]), abs(delta_rvec[2]))
|
||||
|
||||
max_normalized = max(max_normalized_pos, normalized_rot)
|
||||
|
||||
if max_normalized > 1.0:
|
||||
# Scale proportionally
|
||||
delta_pos = delta_pos / max_normalized
|
||||
delta_rvec = delta_rvec / max_normalized
|
||||
|
||||
intervention_action = np.array(
|
||||
[
|
||||
delta_pos[0],
|
||||
delta_pos[1],
|
||||
delta_pos[2],
|
||||
delta_rvec[0],
|
||||
delta_rvec[1],
|
||||
delta_rvec[2],
|
||||
np.clip(delta_gripper, -self.max_gripper_pos, self.max_gripper_pos)
|
||||
/ self.max_gripper_pos,
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
|
||||
# # Extract leader positions from teleop action dict
|
||||
# # leader_pos = np.array([teleop_action.get(f"{motor}.pos", 0) for motor in self.motor_names])
|
||||
# # follower_pos = np.array([raw_joint_pos[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
# teleop_action = self.leader_device.bus.sync_read("Present_Position")
|
||||
# raw_joint_pos = self.robot.bus.sync_read("Present_Position")
|
||||
# leader_pos = np.array([teleop_action.get(f"{motor}", 0) for motor in self.motor_names])
|
||||
# follower_pos = np.array([raw_joint_pos[f"{motor}"] for motor in self.motor_names])
|
||||
|
||||
# # Compute EE positions
|
||||
# leader_ee_fi = self.kinematics.forward_kinematics(leader_pos)
|
||||
# leader_ee_pos = leader_ee_fi[:3, 3]
|
||||
# # leader_ee_rot = Rotation.from_matrix(leader_ee_fi[:3, :3]).as_rotvec()
|
||||
# leader_ee = np.concat([leader_ee_pos, [0,0,0]])
|
||||
|
||||
# if "IK_solution" in transition.get(TransitionKey.COMPLEMENTARY_DATA):
|
||||
# follower_ee = transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"]
|
||||
# else:
|
||||
# follower_pos = np.array([raw_joint_pos[f"{motor}.pos"] for motor in self.motor_names])
|
||||
# follower_ee_fi = self.kinematics.forward_kinematics(follower_pos)
|
||||
# follower_ee_pos = follower_ee_fi[:3, 3]
|
||||
# # follower_ee_rot = Rotation.from_matrix(follower_ee_fi[:3, :3]).as_rotvec()
|
||||
# follower_ee = np.concat([follower_ee_pos, [0,0,0]])
|
||||
|
||||
# # Compute normalized EE delta
|
||||
# if self.end_effector_step_sizes is not None:
|
||||
# ee_delta = np.clip(
|
||||
# leader_ee - follower_ee,
|
||||
# -self.end_effector_step_sizes,
|
||||
# self.end_effector_step_sizes
|
||||
# )
|
||||
# ee_delta_normalized = ee_delta / self.end_effector_step_sizes
|
||||
# else:
|
||||
# ee_delta_normalized = leader_ee - follower_ee
|
||||
|
||||
# # Handle gripper
|
||||
# if self.use_gripper and len(leader_pos) > 3:
|
||||
# if self.prev_leader_gripper is None:
|
||||
# self.prev_leader_gripper = np.clip(
|
||||
# leader_pos[-1], 0, self.max_gripper_pos
|
||||
# )
|
||||
|
||||
# leader_gripper = leader_pos[-1]
|
||||
# gripper_delta = leader_gripper - self.prev_leader_gripper
|
||||
# normalized_delta = gripper_delta / self.max_gripper_pos
|
||||
|
||||
# # Quantize gripper action
|
||||
# if normalized_delta >= 0.3:
|
||||
# gripper_action = 2
|
||||
# elif normalized_delta <= -0.1:
|
||||
# gripper_action = 0
|
||||
# else:
|
||||
# gripper_action = 1
|
||||
|
||||
# self.prev_leader_gripper = leader_gripper
|
||||
|
||||
# # Create intervention action
|
||||
# intervention_action = np.append(ee_delta_normalized, gripper_action)
|
||||
# else:
|
||||
# intervention_action = ee_delta_normalized
|
||||
|
||||
# # Override teleop_action with computed EE action
|
||||
complementary["teleop_action"] = torch.from_numpy(intervention_action).float()
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary # type: ignore[misc]
|
||||
|
||||
return transition
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset leader-follower state."""
|
||||
# self.prev_leader_gripper = None
|
||||
if hasattr(self.leader_device, "reset"):
|
||||
self.leader_device.reset()
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
@@ -134,15 +134,6 @@ class _NormalizationMixin:
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
@@ -161,7 +152,6 @@ class _NormalizationMixin:
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -211,7 +201,6 @@ class _NormalizationMixin:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
self._reshape_visual_stats()
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
@@ -222,7 +211,6 @@ class _NormalizationMixin:
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .factory import (
|
||||
get_reward_model_class as get_reward_model_class,
|
||||
make_reward_model as make_reward_model,
|
||||
make_reward_model_config as make_reward_model_config,
|
||||
make_reward_pre_post_processors as make_reward_pre_post_processors,
|
||||
)
|
||||
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"RewardClassifierConfig",
|
||||
"SARMConfig",
|
||||
# Base class
|
||||
"PreTrainedRewardModel",
|
||||
# Factory functions
|
||||
"get_reward_model_class",
|
||||
"make_reward_model",
|
||||
"make_reward_model_config",
|
||||
"make_reward_pre_post_processors",
|
||||
]
|
||||
+5
-6
@@ -1,5 +1,3 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -15,14 +13,15 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import NormalizationMode, PreTrainedConfig
|
||||
from lerobot.configs import NormalizationMode
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@RewardModelConfig.register_subclass(name="reward_classifier")
|
||||
@dataclass
|
||||
class RewardClassifierConfig(PreTrainedConfig):
|
||||
class RewardClassifierConfig(RewardModelConfig):
|
||||
"""Configuration for the Reward Classifier model."""
|
||||
|
||||
name: str = "reward_classifier"
|
||||
@@ -31,7 +30,7 @@ class RewardClassifierConfig(PreTrainedConfig):
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "lerobot/resnet10"
|
||||
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
+13
-32
@@ -1,5 +1,3 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -19,11 +17,10 @@ import logging
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
|
||||
from ...pretrained import PreTrainedPolicy
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
@@ -99,7 +96,7 @@ class SpatialLearnedEmbeddings(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class Classifier(PreTrainedPolicy):
|
||||
class Classifier(PreTrainedRewardModel):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "reward_classifier"
|
||||
@@ -108,7 +105,6 @@ class Classifier(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
**kwargs,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
@@ -236,6 +232,16 @@ class Classifier(PreTrainedPolicy):
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Returns 1.0 for success, 0.0 for failure based on image observations."""
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
output = self.predict(images)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
return (output.probabilities > 0.5).float()
|
||||
else:
|
||||
return torch.argmax(output.probabilities, dim=1).float()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Extract images and labels
|
||||
@@ -279,28 +285,3 @@ class Classifier(PreTrainedPolicy):
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return optimizer parameters for the policy."""
|
||||
return self.parameters()
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not select actions")
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not produce action chunks.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not predict action chunks")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
pass
|
||||
+1
-6
@@ -1,5 +1,3 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -27,8 +25,7 @@ from lerobot.processor import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
def make_classifier_processor(
|
||||
@@ -52,8 +49,6 @@ def make_classifier_processor(
|
||||
Args:
|
||||
config: The configuration object for the RewardClassifier.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
@@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
|
||||
|
||||
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
"""
|
||||
Retrieves a reward model class by its registered name.
|
||||
|
||||
This function uses dynamic imports to avoid loading all reward model classes into
|
||||
memory at once, improving startup time and reducing dependencies.
|
||||
|
||||
Args:
|
||||
name: The name of the reward model. Supported names are "reward_classifier",
|
||||
"sarm".
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
|
||||
Raises:
|
||||
ValueError: If the reward model name is not recognized.
|
||||
"""
|
||||
if name == "reward_classifier":
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
elif name == "sarm":
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Reward model type '{name}' is not available.") from e
|
||||
|
||||
|
||||
def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
"""
|
||||
Instantiates a reward model configuration object based on the reward type.
|
||||
|
||||
This factory function simplifies the creation of reward model configuration objects
|
||||
by mapping a string identifier to the corresponding config class.
|
||||
|
||||
Args:
|
||||
reward_type: The type of the reward model. Supported types include
|
||||
"reward_classifier", "sarm".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of a `RewardModelConfig` subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `reward_type` is not recognized.
|
||||
"""
|
||||
if reward_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif reward_type == "sarm":
|
||||
return SARMConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
return config_cls(**kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Reward model type '{reward_type}' is not available.") from e
|
||||
|
||||
|
||||
def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel:
|
||||
"""
|
||||
Instantiate a reward model from its configuration.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the reward model to be created. If
|
||||
`cfg.pretrained_path` is set, the model will be loaded with weights
|
||||
from that path.
|
||||
**kwargs: Additional keyword arguments forwarded to the model constructor
|
||||
(e.g., ``dataset_stats``, ``dataset_meta``).
|
||||
|
||||
Returns:
|
||||
An instantiated and device-placed reward model.
|
||||
"""
|
||||
reward_cls = get_reward_model_class(cfg.type)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
reward_model = reward_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
reward_model = reward_cls(**kwargs)
|
||||
|
||||
reward_model.to(cfg.device)
|
||||
assert isinstance(reward_model, torch.nn.Module)
|
||||
|
||||
return reward_model
|
||||
|
||||
|
||||
def make_reward_pre_post_processors(
|
||||
reward_cfg: RewardModelConfig,
|
||||
**kwargs,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Create pre- and post-processor pipelines for a given reward model.
|
||||
|
||||
Each reward model type has a dedicated factory function for its processors.
|
||||
|
||||
Args:
|
||||
reward_cfg: The configuration of the reward model for which to create processors.
|
||||
**kwargs: Additional keyword arguments passed to the processor factory
|
||||
(e.g., ``dataset_stats``, ``dataset_meta``).
|
||||
|
||||
Returns:
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
|
||||
Raises:
|
||||
ValueError: If a processor factory is not implemented for the given reward
|
||||
model configuration type.
|
||||
"""
|
||||
# Create a new processor based on reward model type
|
||||
if isinstance(reward_cfg, RewardClassifierConfig):
|
||||
from lerobot.rewards.classifier.processor_classifier import make_classifier_processor
|
||||
|
||||
return make_classifier_processor(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, SARMConfig):
|
||||
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
return make_sarm_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Processor for reward model type '{reward_cfg.type}' is not implemented."
|
||||
) from e
|
||||
return processors
|
||||
|
||||
|
||||
def _get_reward_model_cls_from_name(name: str) -> type[PreTrainedRewardModel]:
|
||||
"""Get reward model class from its registered name using dynamic imports.
|
||||
|
||||
This is used as a helper function to import reward models from 3rd party lerobot
|
||||
plugins.
|
||||
|
||||
Args:
|
||||
name: The name of the reward model.
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
"""
|
||||
if name not in RewardModelConfig.get_known_choices():
|
||||
raise ValueError(
|
||||
f"Unknown reward model name '{name}'. "
|
||||
f"Available reward models: {RewardModelConfig.get_known_choices()}"
|
||||
)
|
||||
|
||||
config_cls = RewardModelConfig.get_choice_class(name)
|
||||
config_cls_name = config_cls.__name__
|
||||
|
||||
model_name = config_cls_name.removesuffix("Config")
|
||||
if model_name == config_cls_name:
|
||||
raise ValueError(
|
||||
f"The config class name '{config_cls_name}' does not follow the expected naming convention. "
|
||||
f"Make sure it ends with 'Config'!"
|
||||
)
|
||||
|
||||
cls_name = model_name + "RewardModel"
|
||||
module_path = config_cls.__module__.replace("configuration_", "modeling_")
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
reward_cls = getattr(module, cls_name)
|
||||
return reward_cls
|
||||
|
||||
|
||||
def _make_processors_from_reward_model_config(
|
||||
config: RewardModelConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[Any, Any]:
|
||||
"""Create pre- and post-processors from a reward model configuration using dynamic imports.
|
||||
|
||||
This is used as a helper function to import processor factories from 3rd party
|
||||
lerobot reward model plugins.
|
||||
|
||||
Args:
|
||||
config: The reward model configuration object.
|
||||
dataset_stats: Dataset statistics for normalization.
|
||||
|
||||
Returns:
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
"""
|
||||
reward_type = config.type
|
||||
function_name = f"make_{reward_type}_pre_post_processors"
|
||||
module_path = config.__class__.__module__.replace("configuration_", "processor_")
|
||||
logging.debug(
|
||||
f"Instantiating reward pre/post processors using function '{function_name}' "
|
||||
f"from module '{module_path}'"
|
||||
)
|
||||
module = importlib.import_module(module_path)
|
||||
function = getattr(module, function_name)
|
||||
return function(config, dataset_stats=dataset_stats)
|
||||
@@ -0,0 +1,244 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedRewardModel")
|
||||
|
||||
|
||||
class PreTrainedRewardModel(nn.Module, HubMixin, abc.ABC):
|
||||
"""Base class for reward models."""
|
||||
|
||||
config_class: None
|
||||
name: None
|
||||
|
||||
def __init__(self, config: RewardModelConfig, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, RewardModelConfig):
|
||||
raise ValueError(
|
||||
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
|
||||
"`RewardModelConfig`. To create a model from a pretrained model use "
|
||||
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "config_class", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: RewardModelConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""
|
||||
The reward model is set in evaluation mode by default using `reward.eval()` (dropout modules are
|
||||
deactivated). To train it, you should first set it back in training mode with `reward.train()`.
|
||||
"""
|
||||
if config is None:
|
||||
config = RewardModelConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
reward.to(config.device)
|
||||
reward.eval()
|
||||
return reward
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
# Create base kwargs
|
||||
kwargs = {"strict": strict}
|
||||
|
||||
# Add device parameter for newer versions that support it
|
||||
if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
|
||||
kwargs["device"] = map_location
|
||||
|
||||
# Load the model with appropriate kwargs
|
||||
missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
|
||||
if missing_keys:
|
||||
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
|
||||
|
||||
# For older versions, manually move to device if needed
|
||||
if "device" not in kwargs and map_location != "cpu":
|
||||
logging.warning(
|
||||
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
||||
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
||||
" This leads to a slower loading time."
|
||||
" Please update safetensors to version 0.4.3 or above for improved performance."
|
||||
)
|
||||
model.to(map_location)
|
||||
return model
|
||||
|
||||
def get_optim_params(self):
|
||||
"""
|
||||
Returns the reward-model-specific parameters dict to be passed on to the optimizer.
|
||||
"""
|
||||
return self.parameters()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset any internal state."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute a scalar reward signal for a batch of observations.
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing at minimum observation tensors.
|
||||
May also contain "action", "next_observation.*", etc.
|
||||
|
||||
Returns:
|
||||
Tensor of shape ``(batch_size,)`` with reward values.
|
||||
"""
|
||||
...
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Training forward pass — override for trainable reward models."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} is not trainable. Only use compute_reward() for inference."
|
||||
)
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
"""Whether this reward model can be trained via ``lerobot-train``.
|
||||
|
||||
Trainable reward models override :meth:`forward`; zero-shot models
|
||||
inherit the base implementation that raises ``NotImplementedError``.
|
||||
"""
|
||||
return type(self).forward is not PreTrainedRewardModel.forward
|
||||
|
||||
def push_model_to_hub(self, cfg: "TrainPipelineConfig"):
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
# Push the files to the repo in a single commit
|
||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||
saved_path = Path(tmp) / repo_id
|
||||
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
cfg.save_pretrained(saved_path) # Calls _save_pretrained and stores train config
|
||||
|
||||
commit_info = api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
folder_path=saved_path,
|
||||
commit_message="Upload reward model weights, train config and readme",
|
||||
allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.md"],
|
||||
ignore_patterns=["*.tmp", "*.log"],
|
||||
)
|
||||
|
||||
logging.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
|
||||
def generate_model_card(
|
||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||
) -> ModelCard:
|
||||
card_data = ModelCardData(
|
||||
license=license or "apache-2.0",
|
||||
library_name="lerobot",
|
||||
pipeline_tag="robotics",
|
||||
tags=list(set(tags or []).union({"robotics", "lerobot", "reward-model", model_type})),
|
||||
model_name=model_type,
|
||||
datasets=dataset_repo_id,
|
||||
)
|
||||
|
||||
template_card = (
|
||||
files("lerobot.templates")
|
||||
.joinpath("lerobot_rewardmodel_modelcard_template.md")
|
||||
.read_text(encoding="utf-8")
|
||||
)
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,5 +14,6 @@
|
||||
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .modeling_sarm import SARMRewardModel
|
||||
from .processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
__all__ = ["SARMConfig", "SARMRewardModel"]
|
||||
__all__ = ["SARMConfig", "SARMRewardModel", "make_sarm_pre_post_processors"]
|
||||
+8
-9
@@ -25,18 +25,18 @@ need ~num_frames/30 queries instead of one per frame (~30x speedup).
|
||||
|
||||
Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
@@ -58,10 +58,9 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
from .modeling_sarm import SARMRewardModel
|
||||
from .processor_sarm import make_sarm_pre_post_processors
|
||||
from .sarm_utils import normalize_stage_tau
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
from lerobot.rewards.sarm.sarm_utils import normalize_stage_tau
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
@@ -713,12 +712,12 @@ def main():
|
||||
epilog="""
|
||||
Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
+4
-6
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -22,14 +20,15 @@ Paper: https://arxiv.org/abs/2509.25358
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sarm")
|
||||
@RewardModelConfig.register_subclass("sarm")
|
||||
@dataclass
|
||||
class SARMConfig(PreTrainedConfig):
|
||||
class SARMConfig(RewardModelConfig):
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling).
|
||||
|
||||
Supports three annotation modes:
|
||||
@@ -110,7 +109,6 @@ class SARMConfig(PreTrainedConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
|
||||
raise ValueError(
|
||||
f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
|
||||
+23
-17
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -34,14 +32,13 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .sarm_utils import (
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.rewards.sarm.sarm_utils import (
|
||||
normalize_stage_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
|
||||
|
||||
class StageTransformer(nn.Module):
|
||||
@@ -353,7 +350,7 @@ def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
|
||||
return stage_onehot
|
||||
|
||||
|
||||
class SARMRewardModel(PreTrainedPolicy):
|
||||
class SARMRewardModel(PreTrainedRewardModel):
|
||||
"""
|
||||
SARM Reward Model for stage-aware task completion rewards.
|
||||
|
||||
@@ -471,6 +468,23 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
self.subtask_model.to(device)
|
||||
return self
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute dense progress reward in [0, 1] from batch.
|
||||
|
||||
Expects batch to contain:
|
||||
- "observation_features" or video embeddings: (B, T, 512)
|
||||
- "language_embedding" or text embeddings: (B, 512)
|
||||
- optionally "observation.state": (B, T, state_dim)
|
||||
"""
|
||||
text_emb = batch.get("language_embedding", batch.get("text_features"))
|
||||
video_emb = batch.get("observation_features", batch.get("video_features"))
|
||||
state = batch.get("observation.state", batch.get("state_features"))
|
||||
|
||||
rewards = self.calculate_rewards(text_emb, video_emb, state)
|
||||
if isinstance(rewards, np.ndarray):
|
||||
rewards = torch.from_numpy(rewards).float()
|
||||
return rewards
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_rewards(
|
||||
self,
|
||||
@@ -631,17 +645,9 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
"""SARM has no episode-level state to reset."""
|
||||
pass
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
raise NotImplementedError("SARM model does not predict action chunks")
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for SARM."""
|
||||
raise NotImplementedError("SARM model does not select actions")
|
||||
|
||||
def _train_step(
|
||||
self,
|
||||
img_emb: torch.Tensor, # (B, N, T, D)
|
||||
+4
-7
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -60,16 +58,15 @@ from lerobot.processor import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .sarm_utils import (
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.rewards.sarm.sarm_utils import (
|
||||
apply_rewind_augmentation,
|
||||
compute_absolute_indices,
|
||||
find_stage_and_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
class SARMEncodingProcessorStep(ProcessorStep):
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -14,14 +12,38 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation.
|
||||
|
||||
This module implements the SampleWeighter protocol for RA-BC training,
|
||||
which weights training samples based on their task progress as measured
|
||||
by the SARM reward model.
|
||||
|
||||
The weights are computed based on progress deltas:
|
||||
delta = progress[t + chunk_size] - progress[t]
|
||||
|
||||
High-quality samples (positive progress) get higher weights, while
|
||||
samples with negative progress (going backwards) get zero weight.
|
||||
|
||||
See: https://arxiv.org/abs/2509.25358 for the SARM paper.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.utils.import_utils import _pandas_available
|
||||
from lerobot.utils.sample_weighting import SampleWeighter
|
||||
|
||||
if TYPE_CHECKING or _pandas_available:
|
||||
import pandas as pd
|
||||
else:
|
||||
pd = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def resolve_hf_path(path: str | Path) -> Path:
|
||||
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
|
||||
@@ -34,23 +56,27 @@ def resolve_hf_path(path: str | Path) -> Path:
|
||||
return Path(path)
|
||||
|
||||
|
||||
class RABCWeights:
|
||||
class RABCWeights(SampleWeighter):
|
||||
"""
|
||||
Load precomputed SARM progress values and compute RA-BC weights during training.
|
||||
|
||||
This class implements the SampleWeighter ABC for use with the generic
|
||||
sample weighting infrastructure in lerobot.
|
||||
|
||||
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
|
||||
During training, computes:
|
||||
- progress_delta = progress[t + chunk_size] - progress[t]
|
||||
- rabc_weight based on the delta (paper Eq. 8-9)
|
||||
|
||||
Args:
|
||||
progress_path: Path to parquet file with precomputed progress values
|
||||
chunk_size: Number of frames ahead for computing progress delta
|
||||
head_mode: Which SARM head to use ("sparse" or "dense")
|
||||
kappa: Hard threshold for high-quality samples (default: 0.01)
|
||||
epsilon: Small constant for numerical stability (default: 1e-6)
|
||||
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
|
||||
device: Device to return tensors on
|
||||
progress_path: Path to parquet file with precomputed progress values.
|
||||
Supports HuggingFace URLs (hf://datasets/...).
|
||||
chunk_size: Number of frames ahead for computing progress delta.
|
||||
head_mode: Which SARM head to use ("sparse" or "dense").
|
||||
kappa: Hard threshold for high-quality samples (default: 0.01).
|
||||
epsilon: Small constant for numerical stability (default: 1e-6).
|
||||
fallback_weight: Weight to use for frames without valid delta (default: 1.0).
|
||||
device: Device to return tensors on.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -61,7 +87,7 @@ class RABCWeights:
|
||||
kappa: float = 0.01,
|
||||
epsilon: float = 1e-6,
|
||||
fallback_weight: float = 1.0,
|
||||
device: torch.device = None,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
self.progress_path = resolve_hf_path(progress_path)
|
||||
self.chunk_size = chunk_size
|
||||
@@ -87,8 +113,8 @@ class RABCWeights:
|
||||
|
||||
logging.info(f"Using progress column: {self.progress_column}")
|
||||
|
||||
self.progress_lookup = {}
|
||||
self.episode_lookup = {}
|
||||
self.progress_lookup: dict[int, float] = {}
|
||||
self.episode_lookup: dict[int, int] = {}
|
||||
|
||||
for _, row in self.df.iterrows():
|
||||
global_idx = int(row["index"])
|
||||
@@ -100,7 +126,7 @@ class RABCWeights:
|
||||
self.episode_lookup[global_idx] = episode_idx
|
||||
|
||||
# Build episode boundaries for delta computation
|
||||
self.episode_boundaries = {}
|
||||
self.episode_boundaries: dict[int, dict[str, int]] = {}
|
||||
for episode_idx in self.df["episode_index"].unique():
|
||||
ep_df = self.df[self.df["episode_index"] == episode_idx]
|
||||
self.episode_boundaries[int(episode_idx)] = {
|
||||
@@ -114,7 +140,7 @@ class RABCWeights:
|
||||
# Compute global statistics for weight computation
|
||||
self._compute_global_stats()
|
||||
|
||||
def _compute_global_stats(self):
|
||||
def _compute_global_stats(self) -> None:
|
||||
"""Compute global mean and std of progress deltas for weight calculation."""
|
||||
all_deltas = []
|
||||
|
||||
@@ -138,8 +164,8 @@ class RABCWeights:
|
||||
all_deltas.append(delta)
|
||||
|
||||
if all_deltas:
|
||||
self.delta_mean = max(np.mean(all_deltas), 0.0)
|
||||
self.delta_std = max(np.std(all_deltas), self.epsilon)
|
||||
self.delta_mean = max(float(np.mean(all_deltas)), 0.0)
|
||||
self.delta_std = max(float(np.std(all_deltas)), self.epsilon)
|
||||
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
|
||||
else:
|
||||
self.delta_mean = 0.0
|
||||
@@ -157,18 +183,19 @@ class RABCWeights:
|
||||
4. Compute weight using paper Eq. 8-9
|
||||
|
||||
Args:
|
||||
batch: Training batch containing "index" key with global frame indices
|
||||
batch: Training batch containing "index" key with global frame indices.
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Weights tensor (batch_size,) normalized to sum to batch_size
|
||||
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
|
||||
- Weights tensor (batch_size,) normalized to sum to batch_size.
|
||||
- Stats dict with weighting statistics for logging.
|
||||
"""
|
||||
indices = batch.get("index")
|
||||
if indices is None:
|
||||
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
|
||||
batch_size = self._get_batch_size(batch)
|
||||
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
|
||||
stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
|
||||
return torch.ones(batch_size, device=self.device), stats
|
||||
|
||||
# Convert to list of ints
|
||||
if isinstance(indices, torch.Tensor):
|
||||
@@ -183,29 +210,29 @@ class RABCWeights:
|
||||
delta = self._compute_delta(idx)
|
||||
deltas.append(delta)
|
||||
|
||||
deltas = np.array(deltas, dtype=np.float32)
|
||||
deltas_array = np.array(deltas, dtype=np.float32)
|
||||
|
||||
# Compute weights from deltas
|
||||
weights = self._compute_weights(deltas)
|
||||
weights = self._compute_weights(deltas_array)
|
||||
|
||||
# Compute stats before normalization for logging
|
||||
raw_mean_weight = float(np.nanmean(weights))
|
||||
num_zero_weight = int(np.sum(weights == 0))
|
||||
num_full_weight = int(np.sum(weights == 1.0))
|
||||
batch_stats = {
|
||||
"raw_mean_weight": raw_mean_weight,
|
||||
"mean_weight": raw_mean_weight,
|
||||
"num_zero_weight": num_zero_weight,
|
||||
"num_full_weight": num_full_weight,
|
||||
}
|
||||
|
||||
weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
|
||||
weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32)
|
||||
|
||||
# Normalize to sum to batch_size
|
||||
batch_size = len(weights)
|
||||
weight_sum = weights.sum() + self.epsilon
|
||||
weights = weights * batch_size / weight_sum
|
||||
batch_size = len(weights_tensor)
|
||||
weight_sum = weights_tensor.sum() + self.epsilon
|
||||
weights_tensor = weights_tensor * batch_size / weight_sum
|
||||
|
||||
return weights, batch_stats
|
||||
return weights_tensor, batch_stats
|
||||
|
||||
def _compute_delta(self, global_idx: int) -> float:
|
||||
"""Compute progress delta for a single frame."""
|
||||
@@ -241,7 +268,7 @@ class RABCWeights:
|
||||
- Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
|
||||
|
||||
Returns:
|
||||
Array of weights
|
||||
Array of weights.
|
||||
"""
|
||||
valid_mask = ~np.isnan(deltas)
|
||||
|
||||
@@ -273,12 +300,13 @@ class RABCWeights:
|
||||
if key in batch:
|
||||
val = batch[key]
|
||||
if isinstance(val, (torch.Tensor, np.ndarray)):
|
||||
return val.shape[0]
|
||||
return int(val.shape[0])
|
||||
return 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get statistics."""
|
||||
"""Get global statistics about the RA-BC weighting."""
|
||||
return {
|
||||
"type": "rabc",
|
||||
"num_frames": len(self.progress_lookup),
|
||||
"chunk_size": self.chunk_size,
|
||||
"head_mode": self.head_mode,
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
+16
-26
@@ -12,33 +12,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Reinforcement learning modules.
|
||||
"""
|
||||
Reinforcement learning modules.
|
||||
|
||||
Distributed actor / learner entry points (``actor``, ``learner``,
|
||||
``learner_service``) require ``pip install 'lerobot[hilserl]'``. Algorithms,
|
||||
buffer, data sources and trainer are gRPC-free and usable standalone.
|
||||
Requires: ``pip install 'lerobot[hilserl]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.rl.actor import ...
|
||||
from lerobot.rl.learner import ...
|
||||
from lerobot.rl.learner_service import ...
|
||||
from lerobot.rl.buffer import ...
|
||||
from lerobot.rl.eval_policy import ...
|
||||
from lerobot.rl.gym_manipulator import ...
|
||||
"""
|
||||
|
||||
from .algorithms.base import RLAlgorithm as RLAlgorithm
|
||||
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
|
||||
from .algorithms.factory import (
|
||||
make_algorithm as make_algorithm,
|
||||
make_algorithm_config as make_algorithm_config,
|
||||
)
|
||||
from .algorithms.sac.configuration_sac import SACAlgorithmConfig as SACAlgorithmConfig
|
||||
from .buffer import ReplayBuffer as ReplayBuffer
|
||||
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
|
||||
from .trainer import RLTrainer as RLTrainer
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
__all__ = [
|
||||
"RLAlgorithm",
|
||||
"RLAlgorithmConfig",
|
||||
"TrainingStats",
|
||||
"make_algorithm",
|
||||
"make_algorithm_config",
|
||||
"SACAlgorithmConfig",
|
||||
"RLTrainer",
|
||||
"ReplayBuffer",
|
||||
"DataMixer",
|
||||
"OnlineOfflineMixer",
|
||||
]
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
+49
-44
@@ -51,19 +51,17 @@ import os
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Queue
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
@@ -76,12 +74,14 @@ from lerobot.transport.utils import (
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
@@ -90,11 +90,12 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
from .gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
reset_and_build_transition,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from .queue import get_last_item_from_queue
|
||||
|
||||
# Main entry point
|
||||
|
||||
@@ -211,7 +212,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
|
||||
def act_with_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
shutdown_event: Any, # Event
|
||||
shutdown_event: any, # Event,
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
@@ -251,21 +252,22 @@ def act_with_policy(
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a policy object through the port, we create a policy instance
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy = make_policy(
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.to(device).eval()
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
dataset_stats=cfg.policy.dataset_stats,
|
||||
)
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
transition = reset_and_build_transition(online_env, env_processor, action_processor)
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
@@ -289,17 +291,8 @@ def act_with_policy(
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
normalized_observation = preprocessor.process_observation(observation)
|
||||
action = policy.select_action(batch=normalized_observation)
|
||||
# Unnormalize only the continuous part.
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
continuous_action = postprocessor.process_action(action[..., :-1])
|
||||
discrete_action = action[..., -1:].to(
|
||||
device=continuous_action.device, dtype=continuous_action.dtype
|
||||
)
|
||||
action = torch.cat([continuous_action, discrete_action], dim=-1)
|
||||
else:
|
||||
action = postprocessor.process_action(action)
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
@@ -333,8 +326,7 @@ def act_with_policy(
|
||||
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False))
|
||||
if is_intervention:
|
||||
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
episode_intervention = True
|
||||
episode_intervention_steps += 1
|
||||
|
||||
@@ -342,7 +334,6 @@ def act_with_policy(
|
||||
"discrete_penalty": torch.tensor(
|
||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||
),
|
||||
TeleopEvents.IS_INTERVENTION.value: is_intervention,
|
||||
}
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
@@ -399,7 +390,14 @@ def act_with_policy(
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
transition = reset_and_build_transition(online_env, env_processor, action_processor)
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
@@ -411,7 +409,7 @@ def act_with_policy(
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: services_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Any, # Event
|
||||
shutdown_event: Event, # type: ignore
|
||||
attempts: int = 30,
|
||||
):
|
||||
"""Establish a connection with the learner.
|
||||
@@ -463,7 +461,7 @@ def learner_service_client(
|
||||
def receive_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Any, # Event
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
@@ -515,7 +513,7 @@ def receive_policy(
|
||||
def send_transitions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: Any, # Event
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
@@ -565,7 +563,7 @@ def send_transitions(
|
||||
def send_interactions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Any, # Event
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
@@ -615,11 +613,7 @@ def send_interactions(
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(
|
||||
shutdown_event: Any, # Event
|
||||
transitions_queue: Queue,
|
||||
timeout: float,
|
||||
) -> services_pb2.Empty:
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=timeout)
|
||||
@@ -635,9 +629,9 @@ def transitions_stream(
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: Any, # Event
|
||||
shutdown_event: Event,
|
||||
interactions_queue: Queue,
|
||||
timeout: float,
|
||||
timeout: float, # type: ignore
|
||||
) -> services_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
@@ -658,7 +652,7 @@ def interactions_stream(
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
@@ -673,7 +667,18 @@ def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue,
|
||||
# - Send critic's encoder state when shared_encoder=True
|
||||
# - Skip encoder params entirely when freeze_vision_encoder=True
|
||||
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
|
||||
policy.load_actor_weights(state_dicts, device=device)
|
||||
|
||||
# Load actor state dict
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
policy.actor.load_state_dict(actor_state_dict)
|
||||
|
||||
# Load discrete critic if present
|
||||
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
)
|
||||
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
|
||||
|
||||
# Utilities functions
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.data_sources.data_mixer import DataMixer
|
||||
|
||||
BatchType = dict[str, Any]
|
||||
|
||||
|
||||
class RLAlgorithm(abc.ABC):
|
||||
"""Base for all RL algorithms."""
|
||||
|
||||
config_class: type[RLAlgorithmConfig] | None = None
|
||||
name: str | None = None
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "config_class", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One complete training step.
|
||||
|
||||
The algorithm calls ``next(batch_iterator)`` as many times as it
|
||||
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
|
||||
The iterator is owned by the trainer; the algorithm just consumes
|
||||
from it.
|
||||
"""
|
||||
...
|
||||
|
||||
def configure_data_iterator(
|
||||
self,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
) -> Iterator[BatchType]:
|
||||
"""Create the data iterator this algorithm needs.
|
||||
|
||||
The default implementation uses the standard ``data_mixer.get_iterator()``.
|
||||
Algorithms that need specialised sampling should override this method.
|
||||
"""
|
||||
return data_mixer.get_iterator(
|
||||
batch_size=batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
|
||||
"""Create, store, and return the optimizers needed for training.
|
||||
|
||||
Called on the **learner** side after construction. Subclasses must
|
||||
override this with algorithm-specific optimizer setup.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Return optimizers for checkpointing / external scheduling."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def optimization_step(self) -> int:
|
||||
"""Current learner optimization step.
|
||||
|
||||
Part of the stable contract for checkpoint/resume. Algorithms can
|
||||
either use this default storage or override for custom behavior.
|
||||
"""
|
||||
return getattr(self, "_optimization_step", 0)
|
||||
|
||||
@optimization_step.setter
|
||||
def optimization_step(self, value: int) -> None:
|
||||
self._optimization_step = int(value)
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Policy state-dict to push to actors."""
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner."""
|
||||
@@ -1,76 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingStats:
|
||||
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
|
||||
|
||||
losses: dict[str, float] = field(default_factory=dict)
|
||||
grad_norms: dict[str, float] = field(default_factory=dict)
|
||||
extra: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_log_dict(self) -> dict[str, float]:
|
||||
"""Flatten all stats into a single dict for logging."""
|
||||
|
||||
d: dict[str, float] = {}
|
||||
for name, val in self.losses.items():
|
||||
d[name] = val
|
||||
for name, val in self.grad_norms.items():
|
||||
d[f"{name}_grad_norm"] = val
|
||||
for name, val in self.extra.items():
|
||||
d[name] = val
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
"""Registry for algorithm configs."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Registered name of this algorithm config (e.g. ``"sac"``)."""
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm:
|
||||
"""Construct the :class:`RLAlgorithm` for this config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()")
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
|
||||
"""Build an algorithm config from a policy config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
|
||||
@@ -1,47 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
|
||||
|
||||
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
|
||||
"""Instantiate an `RLAlgorithmConfig` from its registered type name.
|
||||
|
||||
Args:
|
||||
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
|
||||
**kwargs: Keyword arguments forwarded to the config class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of the matching ``RLAlgorithmConfig`` subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``algorithm_type`` is not registered.
|
||||
"""
|
||||
try:
|
||||
cls = RLAlgorithmConfig.get_choice_class(algorithm_type)
|
||||
except KeyError as err:
|
||||
raise ValueError(
|
||||
f"Algorithm type '{algorithm_type}' is not registered. "
|
||||
f"Available: {list(RLAlgorithmConfig.get_known_choices().keys())}"
|
||||
) from err
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
|
||||
return cfg.build_algorithm(policy)
|
||||
@@ -1,18 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]
|
||||
@@ -1,90 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
|
||||
CriticNetworkConfig,
|
||||
GaussianActorConfig,
|
||||
)
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
|
||||
@RLAlgorithmConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""SAC algorithm hyperparameters."""
|
||||
|
||||
# Optimizer learning rates
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
|
||||
# Bellman update
|
||||
discount: float = 0.99
|
||||
use_backup_entropy: bool = True
|
||||
critic_target_update_weight: float = 0.005
|
||||
|
||||
# Critic ensemble
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
|
||||
# Temperature / entropy
|
||||
temperature_init: float = 1.0
|
||||
# Target entropy for automatic temperature tuning. If ``None``, defaults to
|
||||
# ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if
|
||||
# there is a discrete action head).
|
||||
target_entropy: float | None = None
|
||||
|
||||
# Update loop
|
||||
utd_ratio: int = 1
|
||||
policy_update_freq: int = 1
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Optimizations
|
||||
# torch.compile is currently disabled by default
|
||||
use_torch_compile: bool = False
|
||||
|
||||
# Policy config
|
||||
policy_config: GaussianActorConfig | None = None
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
|
||||
"""Build an algorithm config with default hyperparameters for a given policy."""
|
||||
return cls(
|
||||
policy_config=policy_cfg,
|
||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||
if self.policy_config is None:
|
||||
raise ValueError(
|
||||
"SACAlgorithmConfig.policy_config is None. "
|
||||
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
|
||||
"before calling build_algorithm()."
|
||||
)
|
||||
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
return SACAlgorithm(policy=policy, config=self)
|
||||
@@ -1,595 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import (
|
||||
DISCRETE_DIMENSION_INDEX,
|
||||
MLP,
|
||||
DiscreteCritic,
|
||||
GaussianActorObservationEncoder,
|
||||
GaussianActorPolicy,
|
||||
orthogonal_init,
|
||||
)
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.rl.algorithms.base import BatchType, RLAlgorithm
|
||||
from lerobot.rl.algorithms.configs import TrainingStats
|
||||
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.transition import move_state_dict_to_device
|
||||
|
||||
|
||||
class SACAlgorithm(RLAlgorithm):
|
||||
"""Soft Actor-Critic. Owns critics, targets, temperature, and loss computation."""
|
||||
|
||||
config_class = SACAlgorithmConfig
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: GaussianActorPolicy,
|
||||
config: SACAlgorithmConfig,
|
||||
):
|
||||
self.config = config
|
||||
self.policy_config = config.policy_config
|
||||
self.policy = policy
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
self._init_critics(action_dim)
|
||||
self._init_temperature(action_dim)
|
||||
|
||||
self._device = torch.device(self.policy.config.device)
|
||||
self._move_to_device()
|
||||
|
||||
def _init_critics(self, action_dim) -> None:
|
||||
"""Build critic ensemble, targets."""
|
||||
encoder = self.policy.encoder_critic
|
||||
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder.output_dim + action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder.output_dim + action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
# TODO(Khalil): Investigate and fix torch.compile
|
||||
# NOTE: torch.compile is disabled, policy does not converge when enabled.
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.discrete_critic_target = None
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
self.discrete_critic_target = self._init_discrete_critic_target(encoder)
|
||||
|
||||
def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic:
|
||||
"""Build target discrete critic (main network is owned by the policy)."""
|
||||
discrete_critic_target = DiscreteCritic(
|
||||
encoder=encoder,
|
||||
input_dim=encoder.output_dim,
|
||||
output_dim=self.policy_config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
# TODO(Khalil): Compile the discrete critic
|
||||
discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
|
||||
return discrete_critic_target
|
||||
|
||||
def _init_temperature(self, continuous_action_dim: int) -> None:
|
||||
"""Set up temperature parameter (log_alpha) and target entropy."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
total_action_dim = continuous_action_dim + (
|
||||
1 if self.policy_config.num_discrete_actions is not None else 0
|
||||
)
|
||||
self.target_entropy = -total_action_dim / 2
|
||||
|
||||
def _move_to_device(self) -> None:
|
||||
self.policy.to(self._device)
|
||||
self.critic_ensemble.to(self._device)
|
||||
self.critic_target.to(self._device)
|
||||
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
|
||||
if self.discrete_critic_target is not None:
|
||||
self.discrete_critic_target.to(self._device)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def _critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def _discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
clip = self.config.grad_clip_norm
|
||||
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch, include_complementary_info=True)
|
||||
|
||||
loss_critic = self._compute_loss_critic(fb)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_dc.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip)
|
||||
self.optimizers["discrete_critic"].step()
|
||||
|
||||
self._update_target_networks()
|
||||
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch, include_complementary_info=False)
|
||||
|
||||
loss_critic = self._compute_loss_critic(fb)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad = torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip).item()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
stats = TrainingStats(
|
||||
losses={"loss_critic": loss_critic.item()},
|
||||
grad_norms={"critic": critic_grad},
|
||||
)
|
||||
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_dc.backward()
|
||||
dc_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.discrete_critic.parameters(), max_norm=clip
|
||||
).item()
|
||||
self.optimizers["discrete_critic"].step()
|
||||
stats.losses["loss_discrete_critic"] = loss_dc.item()
|
||||
stats.grad_norms["discrete_critic"] = dc_grad
|
||||
|
||||
if self._optimization_step % self.config.policy_update_freq == 0:
|
||||
for _ in range(self.config.policy_update_freq):
|
||||
loss_actor = self._compute_loss_actor(fb)
|
||||
self.optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.actor.parameters(), max_norm=clip
|
||||
).item()
|
||||
self.optimizers["actor"].step()
|
||||
|
||||
loss_temp = self._compute_loss_temperature(fb)
|
||||
self.optimizers["temperature"].zero_grad()
|
||||
loss_temp.backward()
|
||||
temp_grad = torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=clip).item()
|
||||
self.optimizers["temperature"].step()
|
||||
|
||||
stats.losses["loss_actor"] = loss_actor.item()
|
||||
stats.losses["loss_temperature"] = loss_temp.item()
|
||||
stats.grad_norms["actor"] = actor_grad
|
||||
stats.grad_norms["temperature"] = temp_grad
|
||||
stats.extra["temperature"] = self.temperature
|
||||
|
||||
self._update_target_networks()
|
||||
self._optimization_step += 1
|
||||
return stats
|
||||
|
||||
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
next_observation_features = batch.get("next_observation_feature")
|
||||
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.policy.actor(
|
||||
next_observations, next_observation_features
|
||||
)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self._critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self._critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
next_observation_features = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self._discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self._discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self._discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
|
||||
actions_pi, log_probs, _ = self.policy.actor(observations, observation_features)
|
||||
|
||||
q_preds = self._critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
observations = batch["state"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.policy.actor(observations, observation_features)
|
||||
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def _update_target_networks(self) -> None:
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_p, p in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
|
||||
):
|
||||
target_p.data.copy_(
|
||||
p.data * self.config.critic_target_update_weight
|
||||
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
for target_p, p in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.policy.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_p.data.copy_(
|
||||
p.data * self.config.critic_target_update_weight
|
||||
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def _prepare_forward_batch(
|
||||
self, batch: BatchType, *, include_complementary_info: bool = True
|
||||
) -> dict[str, Any]:
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
observation_features, next_observation_features = self.get_observation_features(
|
||||
observations, next_observations
|
||||
)
|
||||
forward_batch: dict[str, Any] = {
|
||||
ACTION: batch[ACTION],
|
||||
"reward": batch["reward"],
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": batch["done"],
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
if include_complementary_info and "complementary_info" in batch:
|
||||
forward_batch["complementary_info"] = batch["complementary_info"]
|
||||
return forward_batch
|
||||
|
||||
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
NOTE:
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping component names ("actor", "critic", "temperature")
|
||||
to their respective Adam optimizers.
|
||||
"""
|
||||
actor_params = self.policy.get_optim_params()["actor"]
|
||||
self.optimizers = {
|
||||
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
|
||||
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
|
||||
}
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
|
||||
)
|
||||
return self.optimizers
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
return self.optimizers
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Send actor + discrete-critic state dicts."""
|
||||
state_dicts: dict[str, Any] = {
|
||||
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
|
||||
}
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
self.policy.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
return state_dicts
|
||||
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load actor + discrete-critic weights into the policy."""
|
||||
self.policy.load_actor_weights(weights, device=device)
|
||||
|
||||
def get_observation_features(
|
||||
self, observations: Tensor, next_observations: Tensor
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""
|
||||
Get observation features from the policy encoder. It act as cache for the observation features.
|
||||
when the encoder is frozen, the observation features are not updated.
|
||||
We can save compute by caching the observation features.
|
||||
|
||||
Args:
|
||||
policy: The policy model
|
||||
observations: The current observations
|
||||
next_observations: The next observations
|
||||
|
||||
Returns:
|
||||
tuple: observation_features, next_observation_features
|
||||
"""
|
||||
|
||||
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = self.policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: float | None = None,
|
||||
init_final: float | None = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
||||
|
||||
Args:
|
||||
encoder (GaussianActorObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user