refactor to use relative state

This commit is contained in:
Pepijn
2026-04-01 17:23:58 +02:00
parent 0fc855df13
commit 58bd11caf3
14 changed files with 502 additions and 296 deletions
+16 -1
View File
@@ -202,11 +202,22 @@ Here is how the different processors compose. Each arrow is a processor step, an
└─────────────────────────────────────────┘ └─────────────────────────────────────────┘
┌─────────────────────────────────────────┐ ┌─────────────────────────────────────────┐
Representation │ Absolute ────→ Relative State Derivation │ Action column ────→ State + Action
│ DeriveStateFromActionStep (pre only) │
│ (UMI-style: state from action chunk) │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
Action Repr. │ Absolute ←────→ Relative │
│ RelativeActionsProcessorStep (pre) │ │ RelativeActionsProcessorStep (pre) │
│ AbsoluteActionsProcessorStep (post) │ │ AbsoluteActionsProcessorStep (post) │
└─────────────────────────────────────────┘ └─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
State Repr. │ Absolute ────→ Relative │
│ RelativeStateProcessorStep (pre only) │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐ ┌─────────────────────────────────────────┐
Normalization │ Raw ←────→ Normalized │ Normalization │ Raw ←────→ Normalized │
│ NormalizerProcessorStep (pre) │ │ NormalizerProcessorStep (pre) │
@@ -216,6 +227,10 @@ Here is how the different processors compose. Each arrow is a processor step, an
A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`. A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`.
With UMI-style relative proprioception (`use_relative_state=True`), the preprocessor also converts observation.state to offsets from the current timestep via `RelativeStateProcessorStep` before normalization. This is a pre-processing-only step (state is an input, not an output).
With `derive_state_from_action=True`, the preprocessor first runs `DeriveStateFromActionStep` to extract a 2-step state from the extended action chunk. This enables full UMI-style training without a separate `observation.state` column. See the [UMI pi0 guide](umi_pi0_relative_ee) for details.
## References ## References
- [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions. - [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions.
+75 -61
View File
@@ -4,16 +4,13 @@ This guide explains how to prepare a UMI-collected dataset for training a pi0 po
**What we will do:** **What we will do:**
1. How to add `observation.state` to an existing UMI LeRobot dataset. 1. Recompute dataset statistics for relative actions and state.
2. How to train pi0 with `use_relative_actions=True`. 2. Train pi0 with `derive_state_from_action=true` (full UMI pipeline).
3. How to evaluate the trained policy on a real robot. 3. Evaluate the trained policy on a real robot.
## Background ## Background
[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. UMI datasets stored in LeRobot format already contain `action` (absolute EE pose) and wrist-camera images. To train pi0 with relative actions, we need two additions: [UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. UMI datasets stored in LeRobot format already contain `action` (absolute EE pose) and wrist-camera images. To train pi0 with relative actions, we need **relative action statistics** — so the normalizer sees `(action state)` distributions.
1. **`observation.state`** — the current EE pose the policy conditions on.
2. **Relative action statistics** — so the normalizer sees `(action state)` distributions.
### Why relative actions? ### Why relative actions?
@@ -25,72 +22,39 @@ relative_action[i] = absolute_action[t + i] state[t]
This is the representation advocated by UMI (Chi et al., 2024). Compared to absolute actions it removes the need for a consistent global coordinate frame, and compared to delta actions (each step relative to the previous) it avoids error accumulation across the chunk. See the [Action Representations](action_representations) guide for a full comparison. This is the representation advocated by UMI (Chi et al., 2024). Compared to absolute actions it removes the need for a consistent global coordinate frame, and compared to delta actions (each step relative to the previous) it avoids error accumulation across the chunk. See the [Action Representations](action_representations) guide for a full comparison.
## State-Action Offset ### Full UMI mode: `derive_state_from_action`
UMI SLAM produces a single trajectory of EE poses stored as `action`. We derive `observation.state` from the same trajectory with a configurable offset: When `derive_state_from_action=true`, pi0 automatically derives `observation.state` from the action column on the fly — no separate state column or dataset conversion step needed. Under the hood:
``` - `action_delta_indices` extends to `[-1, 0, 1, ..., 49]` (one extra leading timestep).
state[t] = action[t - offset] - `DeriveStateFromActionStep` extracts `[action[t-1], action[t]]` as a 2-step state and strips the extra timestep from the action chunk.
``` - `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
- `RelativeStateProcessorStep` converts the 2-step state to relative proprioception (velocity + zeros) and flattens.
| Offset | `state[t]` | Meaning | This single flag implies `use_relative_state=true` and `state_obs_steps=2`.
| ------ | ------------- | ---------------------------------------------------------------- |
| 0 | `action[t]` | State and action are the same pose at time t |
| 1 | `action[t-1]` | State is the previous action — where the gripper already arrived |
An offset of 1 is the typical UMI convention: at decision time the "current state" is where the gripper _already is_ (the result of the previous command), and the action is where it should go next. At episode boundaries where `t < offset`, we clamp to `action[0]`. During **inference**, state comes from the robot (via FK), so `DeriveStateFromActionStep` is a no-op. `RelativeStateProcessorStep` buffers the previous state and applies the same conversion automatically.
## Step 1: Add `observation.state` ## Step 1: Recompute Stats
pi0 with `use_relative_actions=True` needs `observation.state` in the dataset to compute `action - state` on the fly. The script in `examples/umi_pi0_relative_ee/convert_umi_dataset.py` adds it. Edit the constants at the top: Use the built-in CLI to recompute dataset statistics for relative actions and derive-state-from-action:
```python
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
# Option A: Copy an existing feature as observation.state
STATE_SOURCE_FEATURE = "observation.joints" # or "observation.pose", etc.
# Option B: Derive from action with offset (set STATE_SOURCE_FEATURE = None)
STATE_SOURCE_FEATURE = None
STATE_ACTION_OFFSET = 1
```
**Choosing the state source:**
- If your dataset already has a feature in the same space as `action` (e.g. `observation.joints` for joint-space actions, or `observation.pose` for EE-space actions), set `STATE_SOURCE_FEATURE` to copy it.
- If your dataset only has a single trajectory (like raw UMI EE data where action = the EE poses), set `STATE_SOURCE_FEATURE = None` and use `STATE_ACTION_OFFSET` to derive state from the action column with a time offset.
`observation.state` **must have the same shape as `action`** — the relative conversion computes `action - state` element-wise.
Then run:
```bash
python examples/umi_pi0_relative_ee/convert_umi_dataset.py
```
<Tip>
If your dataset already has `observation.state`, the script exits early — nothing to do.
</Tip>
## Step 2: Recompute Relative Action Stats
Use the built-in CLI to recompute dataset statistics in relative space:
```bash ```bash
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id <your_dataset> \ --repo_id <your_dataset> \
--operation.type recompute_stats \ --operation.type recompute_stats \
--operation.relative_action true \ --operation.relative_action true \
--operation.derive_state_from_action true \
--operation.chunk_size 50 \ --operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']" \ --operation.relative_exclude_joints "['gripper']" \
--push_to_hub true --push_to_hub true
``` ```
The `derive_state_from_action` flag tells `recompute_stats` to read from the action column (instead of `observation.state`) when computing relative state stats. It automatically enables `relative_state=true` and `state_obs_steps=2`.
The `relative_exclude_joints` parameter specifies joints that stay absolute. Gripper commands are typically binary or continuous open/close and don't benefit from relative encoding. Leave it as `"[]"` to convert all dimensions to relative. The `relative_exclude_joints` parameter specifies joints that stay absolute. Gripper commands are typically binary or continuous open/close and don't benefit from relative encoding. Leave it as `"[]"` to convert all dimensions to relative.
## Step 3: Train ## Step 2: Train
No custom training script is needed — standard `lerobot-train` handles everything: No custom training script is needed — standard `lerobot-train` handles everything:
@@ -99,19 +63,26 @@ lerobot-train \
--dataset.repo_id=<hf_username>/<dataset_repo_id> \ --dataset.repo_id=<hf_username>/<dataset_repo_id> \
--policy.type=pi0 \ --policy.type=pi0 \
--policy.pretrained_path=lerobot/pi0_base \ --policy.pretrained_path=lerobot/pi0_base \
--policy.derive_state_from_action=true \
--policy.use_relative_actions=true \ --policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]' --policy.relative_exclude_joints='["gripper"]'
``` ```
`derive_state_from_action=true` auto-enables `use_relative_state=true` and `state_obs_steps=2`.
Under the hood, the training pipeline: Under the hood, the training pipeline:
- Loads relative action stats from the dataset's `meta/stats.json`. - Loads relative action stats and relative state stats from the dataset's `meta/stats.json`.
- Configures `RelativeActionsProcessorStep` in the preprocessor (absolute → relative before normalization). - Extends `action_delta_indices` to `[-1, 0, 1, ..., 49]` to load one extra leading timestep.
- The model trains on normalized relative action values. - `DeriveStateFromActionStep` extracts the 2-step state from the action chunk and strips the extra timestep.
- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
- `RelativeStateProcessorStep` converts the 2-step state to relative offsets from the current timestep, then flattens.
- `NormalizerProcessorStep` normalizes everything.
- The model trains on normalized relative values.
See the [pi0 documentation](pi0) for all available training options. See the [pi0 documentation](pi0) for all available training options.
## Step 4: Evaluate ## Step 3: Evaluate
The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real robot (SO-100 with EE space): The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real robot (SO-100 with EE space):
@@ -121,10 +92,20 @@ python examples/umi_pi0_relative_ee/evaluate.py
Edit `HF_MODEL_ID`, `HF_DATASET_ID`, and robot configuration at the top of the file. Edit `HF_MODEL_ID`, `HF_DATASET_ID`, and robot configuration at the top of the file.
### Latency compensation
For real robot deployment, you may want to skip the first few steps of each predicted action chunk to compensate for system latency. Set `LATENCY_SKIP_STEPS` in the evaluate script:
```python
LATENCY_SKIP_STEPS = 0 # ceil(total_latency_ms / (1000 / FPS))
```
For example, at 10Hz with ~200ms total latency, set `LATENCY_SKIP_STEPS = 2`.
The inference flow uses pi0's built-in processor pipeline — no custom wrappers needed: The inference flow uses pi0's built-in processor pipeline — no custom wrappers needed:
1. **Robot → FK** — Joint positions are converted to EE pose via `ForwardKinematicsJointsToEE`, producing `observation.state`. 1. **Robot → FK** — Joint positions are converted to EE pose via `ForwardKinematicsJointsToEE`, producing `observation.state`.
2. **Preprocessor** — `RelativeActionsProcessorStep` caches the raw `observation.state`, then `NormalizerProcessorStep` normalizes everything. 2. **Preprocessor** — `DeriveStateFromActionStep` is a no-op (state comes from robot). `RelativeStateProcessorStep` buffers previous state, stacks, and converts to relative. `RelativeActionsProcessorStep` caches state. `NormalizerProcessorStep` normalizes.
3. **pi0 inference** — The model predicts a normalized relative action chunk. 3. **pi0 inference** — The model predicts a normalized relative action chunk.
4. **Postprocessor** — `UnnormalizerProcessorStep` unnormalizes, then `AbsoluteActionsProcessorStep` adds the cached state back to get absolute EE targets. 4. **Postprocessor** — `UnnormalizerProcessorStep` unnormalizes, then `AbsoluteActionsProcessorStep` adds the cached state back to get absolute EE targets.
5. **IK → Robot** — `InverseKinematicsEEToJoints` converts absolute EE targets to joint commands. 5. **IK → Robot** — `InverseKinematicsEEToJoints` converts absolute EE targets to joint commands.
@@ -132,14 +113,22 @@ The inference flow uses pi0's built-in processor pipeline — no custom wrappers
## How the Pieces Fit Together ## How the Pieces Fit Together
``` ```
Training: Training (full UMI mode: derive_state_from_action=true):
dataset (absolute EE) → RelativeActionsProcessorStep → NormalizerProcessorStep → pi0 model DataLoader (action: B,51,D)
→ DeriveStateFromActionStep (state = action[:,:2,:], action = action[:,1:,:])
→ RelativeActionsProcessorStep (action -= state[:,-1,:])
→ RelativeStateProcessorStep (state offsets from current, flatten → B,2*D)
→ NormalizerProcessorStep → pi0 model
Inference: Inference:
robot joints → FK → observation.state (absolute EE) robot joints → FK → observation.state (absolute EE)
DeriveStateFromActionStep (no-op)
RelativeActionsProcessorStep (caches state) RelativeActionsProcessorStep (caches state)
RelativeStateProcessorStep (buffers prev, stacks, subtracts, flattens)
NormalizerProcessorStep → pi0 model → relative action chunk NormalizerProcessorStep → pi0 model → relative action chunk
UnnormalizerProcessorStep UnnormalizerProcessorStep
@@ -149,6 +138,31 @@ Inference:
IK → joint targets → robot IK → joint targets → robot
``` ```
## Manual mode (without derive_state_from_action)
If your dataset already has `observation.state` (or you want to add it separately), you can skip `derive_state_from_action` and use relative actions + relative state independently:
```bash
# Recompute stats
lerobot-edit-dataset \
--repo_id <your_dataset> \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.relative_state true \
--operation.state_obs_steps 2 \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']"
# Train
lerobot-train \
--dataset.repo_id=<your_dataset> \
--policy.type=pi0 \
--policy.use_relative_actions=true \
--policy.use_relative_state=true \
--policy.state_obs_steps=2 \
--policy.relative_exclude_joints='["gripper"]'
```
## References ## References
- [UMI: Universal Manipulation Interface](https://umi-gripper.github.io) — Chi et al., 2024. Defines relative trajectory actions. - [UMI: Universal Manipulation Interface](https://umi-gripper.github.io) — Chi et al., 2024. Defines relative trajectory actions.
@@ -1,220 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Add ``observation.state`` to an existing LeRobot dataset.
pi0 uses ``observation.state`` as its proprioceptive input AND for
relative action conversion (action state). This script creates
``observation.state`` by concatenating one or more existing features.
Ordering matters: the features whose dimensions correspond to ``action``
must come FIRST, because ``RelativeActionsProcessorStep`` subtracts
``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE
pose) are appended after and are seen by the model but not used for
relative conversion.
Example: action = [proximal, distal], and we want the model to also see
the EE pose:
STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"]
observation.state = [proximal, distal, x, y, z, ax, ay, az]
The relative conversion uses state[:2] = [proximal, distal] to subtract
from action[:2], and the model sees all 8 dimensions.
After running this script, recompute relative action stats:
lerobot-edit-dataset \\
--repo_id <your_dataset> \\
--operation.type recompute_stats \\
--operation.relative_action true \\
--operation.chunk_size 50 \\
--operation.relative_exclude_joints "[]" \\
--push_to_hub true
Usage:
python convert_umi_dataset.py
"""
from __future__ import annotations
import logging
from collections.abc import Callable
import numpy as np
from lerobot.datasets.dataset_tools import add_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
HF_DATASET_ID = ""
# Output repo ID. Set to None for default "<input>_modified".
OUTPUT_REPO_ID: str | None = None
# Features to concatenate into observation.state. Order matters:
# action-matching features FIRST, then extra proprioception.
# Set to a single string to copy one feature directly.
STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"]
# Only used when STATE_SOURCE_FEATURES is None:
# derive state from action with a per-episode offset.
STATE_ACTION_OFFSET = 1
# Push the augmented dataset to the Hugging Face Hub.
PUSH_TO_HUB = True
def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]:
hf = dataset.hf_dataset
ep = np.array(hf["episode_index"])
fr = np.array(hf["frame_index"])
return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))}
def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable:
"""Concatenate multiple features into observation.state."""
hf = dataset.hf_dataset
key_to_global = _build_global_index(dataset)
columns = [hf[feat] for feat in source_features]
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
g = key_to_global[(ep_idx, frame_idx)]
parts = []
for col in columns:
val = col[g]
if hasattr(val, "tolist"):
flat = val.tolist()
if isinstance(flat, list):
parts.extend(flat)
else:
parts.append(flat)
elif isinstance(val, list):
parts.extend(val)
else:
parts.append(float(val))
return parts
return _get_state
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
"""Derive state from action with a per-episode offset."""
hf = dataset.hf_dataset
episode_indices = np.array(hf["episode_index"])
frame_indices = np.array(hf["frame_index"])
ep_sorted: dict[int, list[tuple[int, int]]] = {}
for ep_idx in np.unique(episode_indices):
ep_mask = episode_indices == ep_idx
ep_globals = np.where(ep_mask)[0]
ep_frames = frame_indices[ep_globals]
order = np.argsort(ep_frames)
ep_sorted[int(ep_idx)] = [(int(ep_frames[o]), int(ep_globals[o])) for o in order]
ep_frame_to_local: dict[int, dict[int, int]] = {}
for ep_idx, sorted_list in ep_sorted.items():
ep_frame_to_local[ep_idx] = {frame: local for local, (frame, _) in enumerate(sorted_list)}
actions = hf["action"]
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
local_t = ep_frame_to_local[ep_idx][frame_idx]
source_local = max(0, local_t - offset)
_, source_global = ep_sorted[ep_idx][source_local]
return actions[source_global]
return _get_state
def main():
logger.info(f"Loading dataset {HF_DATASET_ID}")
dataset = LeRobotDataset(HF_DATASET_ID)
if "observation.state" in dataset.features:
logger.info("observation.state already exists — nothing to do")
return
action_meta = dataset.features["action"]
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
if STATE_SOURCE_FEATURES is not None:
source_list = (
[STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES)
)
for feat in source_list:
if feat not in dataset.features:
raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}")
# Compute combined shape and names
total_dim = 0
all_names = []
for feat in source_list:
meta = dataset.features[feat]
total_dim += meta["shape"][0]
names = meta.get("names")
if names:
all_names.extend(names)
logger.info(
f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})"
)
state_fn = _build_state_from_features(dataset, source_list)
state_feature_info = {
"dtype": "float32",
"shape": [total_dim],
"names": all_names or None,
}
else:
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
state_fn = _build_state_from_action_offset(dataset, offset=STATE_ACTION_OFFSET)
state_feature_info = {
"dtype": "float32",
"shape": list(action_meta["shape"]),
"names": action_meta.get("names"),
}
augmented = add_features(
dataset,
features={"observation.state": (state_fn, state_feature_info)},
repo_id=OUTPUT_REPO_ID,
)
logger.info("observation.state added")
if PUSH_TO_HUB:
logger.info(f"Pushing to Hub: {augmented.repo_id}")
augmented.push_to_hub()
logger.info(
f"Done. Dataset at: {augmented.root}\n"
"Now recompute relative action stats:\n"
" lerobot-edit-dataset \\\n"
f" --repo_id {augmented.repo_id} \\\n"
" --operation.type recompute_stats \\\n"
" --operation.relative_action true \\\n"
" --operation.chunk_size 50 \\\n"
' --operation.relative_exclude_joints "[]" \\\n'
" --push_to_hub true"
)
if __name__ == "__main__":
main()
+23 -7
View File
@@ -17,17 +17,20 @@
""" """
Inference script for a pi0 model trained with **relative EE actions**. Inference script for a pi0 model trained with **relative EE actions**.
This uses the built-in ``RelativeActionsProcessorStep`` and This uses the built-in ``DeriveStateFromActionStep`` (no-op at inference),
``AbsoluteActionsProcessorStep`` that are already wired into pi0's ``RelativeActionsProcessorStep``, ``AbsoluteActionsProcessorStep``, and
processor pipeline when ``use_relative_actions=True``. ``RelativeStateProcessorStep`` that are already wired into pi0's processor
pipeline.
The inference loop: The inference loop:
1. Reads joint positions from the robot. 1. Reads joint positions from the robot.
2. Converts to EE pose via forward kinematics (FK). 2. Converts to EE pose via forward kinematics (FK).
This produces ``observation.state`` with the current EE pose. This produces ``observation.state`` with the current EE pose.
3. The pi0 preprocessor: 3. The pi0 preprocessor:
a) ``RelativeActionsProcessorStep`` caches the raw state. a) ``DeriveStateFromActionStep`` no-op (state comes from robot).
b) ``NormalizerProcessorStep`` normalizes state and actions. b) ``RelativeActionsProcessorStep`` caches the raw state.
c) ``RelativeStateProcessorStep`` buffers prev state, stacks, subtracts.
d) ``NormalizerProcessorStep`` normalizes state and actions.
4. pi0 predicts relative action chunk. 4. pi0 predicts relative action chunk.
5. The pi0 postprocessor: 5. The pi0 postprocessor:
a) ``UnnormalizerProcessorStep`` unnormalizes. a) ``UnnormalizerProcessorStep`` unnormalizes.
@@ -51,6 +54,7 @@ from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.processor import ( from lerobot.processor import (
RelativeStateProcessorStep,
RobotProcessorPipeline, RobotProcessorPipeline,
make_default_teleop_action_processor, make_default_teleop_action_processor,
) )
@@ -79,6 +83,11 @@ TASK_DESCRIPTION = "manipulation task"
HF_MODEL_ID = "<hf_username>/<model_repo_id>" HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>" HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
# Latency compensation: skip this many steps from the start of each predicted
# action chunk. Formula: ceil(total_latency_ms / (1000 / FPS)).
# E.g. at 10Hz with ~200ms total system latency: ceil(200 / 100) = 2.
LATENCY_SKIP_STEPS = 0
# EE feature keys produced by ForwardKinematicsJointsToEE # EE feature keys produced by ForwardKinematicsJointsToEE
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"] EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
@@ -94,6 +103,7 @@ def main():
robot = SO100Follower(robot_config) robot = SO100Follower(robot_config)
policy = PI0Policy.from_pretrained(HF_MODEL_ID) policy = PI0Policy.from_pretrained(HF_MODEL_ID)
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
kinematics_solver = RobotKinematics( kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf", urdf_path="./SO101/so101_new_calib.urdf",
@@ -151,9 +161,8 @@ def main():
# Build pre/post processors from the trained model. # Build pre/post processors from the trained model.
# The pi0 processor pipeline already includes: # The pi0 processor pipeline already includes:
# pre: ... → RelativeActionsProcessorStep → NormalizerProcessorStep # pre: ... → RelativeStateProcessorStep → RelativeActionsProcessorStep → NormalizerProcessorStep
# post: UnnormalizerProcessorStep → AbsoluteActionsProcessorStep → ... # post: UnnormalizerProcessorStep → AbsoluteActionsProcessorStep → ...
# These handle the relative ↔ absolute conversion automatically.
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy, policy_cfg=policy,
pretrained_path=HF_MODEL_ID, pretrained_path=HF_MODEL_ID,
@@ -161,6 +170,9 @@ def main():
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
) )
# Find the relative state step (if present) so we can reset its buffer between episodes.
_relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
robot.connect() robot.connect()
listener, events = init_keyboard_listener() listener, events = init_keyboard_listener()
@@ -174,6 +186,10 @@ def main():
for episode_idx in range(NUM_EPISODES): for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Reset relative state buffer so velocity is zero at episode start
for step in _relative_state_steps:
step.reset()
record_loop( record_loop(
robot=robot, robot=robot,
events=events, events=events,
+11
View File
@@ -115,6 +115,17 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError raise NotImplementedError
@property
def state_delta_indices(self) -> list | None: # type: ignore[type-arg]
"""Delta indices specifically for observation.state.
When not None, overrides ``observation_delta_indices`` for the
``observation.state`` key only. Useful for loading state history
(e.g. ``[-1, 0]`` for UMI-style relative proprioception) without
also loading multiple image timesteps.
"""
return None
@abc.abstractmethod @abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig: def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError raise NotImplementedError
+91
View File
@@ -767,3 +767,94 @@ def compute_relative_action_stats(
) )
return stats return stats
def compute_relative_state_stats(
hf_dataset,
features: dict,
state_obs_steps: int = 2,
exclude_joints: list[str] | None = None,
source_key: str = OBS_STATE,
) -> dict[str, np.ndarray]:
"""Compute normalization statistics for observation.state after relative conversion.
For UMI-style relative proprioception with ``state_obs_steps`` timesteps,
each state observation becomes a stack of offsets from the current timestep:
``state[t-k] - state[t]`` for k in ``range(state_obs_steps-1, -1, -1)``.
The stats are computed over the flattened ``[state_obs_steps * state_dim]``
vector that the model actually sees after ``prepare_state`` flattening.
Args:
hf_dataset: The HuggingFace dataset with the source column and
"episode_index" columns.
features: Dataset feature metadata.
state_obs_steps: Number of observation timesteps (must be >= 2).
exclude_joints: State dimension names to keep absolute.
source_key: Column to read data from. Defaults to "observation.state".
When ``derive_state_from_action=True``, pass ``ACTION`` to read
from the action column instead.
Returns:
Statistics dict with keys "mean", "std", "min", "max", "q01", , "q99".
"""
from lerobot.processor.relative_action_processor import RelativeStateProcessorStep
if exclude_joints is None:
exclude_joints = []
state_dim = features[source_key]["shape"][0]
state_names = features.get(source_key, {}).get("names")
mask_step = RelativeStateProcessorStep(
enabled=True,
exclude_joints=exclude_joints,
state_names=state_names,
)
relative_mask = np.array(mask_step._build_mask(state_dim), dtype=np.float32)
logging.info(f"Loading data from '{source_key}' for relative state stats...")
all_states = np.array(hf_dataset[source_key], dtype=np.float32)
episode_indices = np.array(hf_dataset["episode_index"])
# Build all valid windows of length state_obs_steps within each episode
n = len(all_states)
if n < state_obs_steps:
raise ValueError(f"Dataset has {n} frames but state_obs_steps={state_obs_steps}")
max_start = n - state_obs_steps
starts = np.arange(max_start + 1)
valid = episode_indices[starts] == episode_indices[starts + state_obs_steps - 1]
valid_starts = starts[valid]
if len(valid_starts) == 0:
raise RuntimeError("No valid state windows found within single episodes")
offsets = np.arange(state_obs_steps)
mask_dim = len(relative_mask)
running_stats = RunningQuantileStats()
batch_size = 50_000
for i in range(0, len(valid_starts), batch_size):
batch_starts = valid_starts[i : i + batch_size]
frame_idx = batch_starts[:, None] + offsets[None, :] # [N, state_obs_steps]
windows = all_states[frame_idx].copy() # [N, state_obs_steps, state_dim]
# Subtract current (last) timestep from all timesteps for masked dims
current = windows[:, -1:, :] # [N, 1, state_dim]
windows[:, :, :mask_dim] -= current[:, :, :mask_dim] * relative_mask[None, None, :]
# Flatten to [N, state_obs_steps * state_dim] (same as prepare_state)
flattened = windows.reshape(len(batch_starts), -1)
running_stats.update(flattened)
stats = running_stats.get_statistics()
excluded_dims = int(mask_dim - relative_mask.sum())
logging.info(
f"Relative state stats ({len(valid_starts)} windows, obs_steps={state_obs_steps}): "
f"relative_dims={int(relative_mask.sum())}/{mask_dim} (excluded={excluded_dims}), "
f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}"
)
return stats
+34
View File
@@ -41,6 +41,7 @@ from lerobot.datasets.compute_stats import (
aggregate_stats, aggregate_stats,
compute_episode_stats, compute_episode_stats,
compute_relative_action_stats, compute_relative_action_stats,
compute_relative_state_stats,
) )
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.io_utils import ( from lerobot.datasets.io_utils import (
@@ -1544,6 +1545,10 @@ def recompute_stats(
relative_exclude_joints: list[str] | None = None, relative_exclude_joints: list[str] | None = None,
chunk_size: int = 50, chunk_size: int = 50,
num_workers: int = 0, num_workers: int = 0,
relative_state: bool = False,
relative_exclude_state_joints: list[str] | None = None,
state_obs_steps: int = 2,
derive_state_from_action: bool = False,
) -> LeRobotDataset: ) -> LeRobotDataset:
"""Recompute stats.json from scratch by iterating all episodes. """Recompute stats.json from scratch by iterating all episodes.
@@ -1561,10 +1566,22 @@ def recompute_stats(
``policy.chunk_size``. Only used when ``relative_action=True``. ``policy.chunk_size``. Only used when ``relative_action=True``.
num_workers: Number of parallel threads for relative action stats computation. num_workers: Number of parallel threads for relative action stats computation.
Values 1 mean single-threaded. Only used when ``relative_action=True``. Values 1 mean single-threaded. Only used when ``relative_action=True``.
relative_state: If True, compute observation.state stats in relative space
(multi-timestep offsets from current). This matches the normalization
the model sees during training with ``use_relative_state=True``.
relative_exclude_state_joints: State dim names to exclude from relative conversion.
state_obs_steps: Number of observation timesteps for relative state stats.
Should match ``policy.state_obs_steps``. Only used when ``relative_state=True``.
derive_state_from_action: If True, compute relative state stats from the
action column instead of observation.state. Implies ``relative_state=True``
and ``state_obs_steps=2``.
Returns: Returns:
The same dataset with updated stats. The same dataset with updated stats.
""" """
if derive_state_from_action:
relative_state = True
state_obs_steps = 2
features = dataset.meta.features features = dataset.meta.features
meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"} meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"}
numeric_features = { numeric_features = {
@@ -1596,6 +1613,20 @@ def recompute_stats(
) )
features_to_compute.pop(ACTION, None) features_to_compute.pop(ACTION, None)
# When relative_state is enabled, compute state stats over the flattened
# multi-timestep relative representation (matching what the model sees).
relative_state_stats = None
if relative_state and (OBS_STATE in features or derive_state_from_action):
source_key = ACTION if derive_state_from_action else OBS_STATE
relative_state_stats = compute_relative_state_stats(
hf_dataset=dataset.hf_dataset,
features=features,
state_obs_steps=state_obs_steps,
exclude_joints=relative_exclude_state_joints,
source_key=source_key,
)
features_to_compute.pop(OBS_STATE, None)
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}") logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
data_dir = dataset.root / DATA_DIR data_dir = dataset.root / DATA_DIR
@@ -1632,6 +1663,9 @@ def recompute_stats(
if relative_action_stats is not None: if relative_action_stats is not None:
new_stats[ACTION] = relative_action_stats new_stats[ACTION] = relative_action_stats
if relative_state_stats is not None:
new_stats[OBS_STATE] = relative_state_stats
# Merge: keep existing stats for features we didn't recompute # Merge: keep existing stats for features we didn't recompute
if dataset.meta.stats: if dataset.meta.stats:
for key, value in dataset.meta.stats.items(): for key, value in dataset.meta.stats.items():
+5 -2
View File
@@ -25,7 +25,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD from lerobot.utils.constants import ACTION, OBS_PREFIX, OBS_STATE, REWARD
IMAGENET_STATS = { IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -52,12 +52,15 @@ def resolve_delta_timestamps(
returns `None` if the resulting dict is empty. returns `None` if the resulting dict is empty.
""" """
delta_timestamps = {} delta_timestamps = {}
state_delta = getattr(cfg, "state_delta_indices", None)
for key in ds_meta.features: for key in ds_meta.features:
if key == REWARD and cfg.reward_delta_indices is not None: if key == REWARD and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == ACTION and cfg.action_delta_indices is not None: if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: if key == OBS_STATE and state_delta is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in state_delta]
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0: if len(delta_timestamps) == 0:
@@ -57,6 +57,28 @@ class PI0Config(PreTrainedConfig):
# Populated at runtime from dataset metadata by make_policy. # Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None action_feature_names: list[str] | None = None
# Relative state (UMI-style relative proprioception): converts multi-timestep
# observation.state to offsets from the current timestep, providing velocity info.
# Requires state_obs_steps >= 2. The flattened multi-timestep state is padded to
# max_state_dim, so ensure state_obs_steps * state_dim <= max_state_dim.
use_relative_state: bool = False
state_obs_steps: int = 1
relative_exclude_state_joints: list[str] = field(default_factory=list)
# Populated at runtime from dataset metadata by make_policy.
state_feature_names: list[str] | None = None
# Derive observation.state from the action column (UMI-style).
# When True, action_delta_indices loads one extra leading timestep [-1, 0, ..., chunk_size-1],
# DeriveStateFromActionStep extracts [action[t-1], action[t]] as a 2-step state,
# and strips the extra timestep from the action chunk.
# Implies use_relative_state=True and state_obs_steps=2.
derive_state_from_action: bool = False
# Latency compensation: skip this many steps from the start of each predicted
# action chunk during inference. E.g. at 10Hz with ~200ms total latency,
# latency_skip_steps=2 compensates for the delay.
latency_skip_steps: int = 0
# Real-Time Chunking (RTC) configuration # Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None rtc_config: RTCConfig | None = None
@@ -106,6 +128,10 @@ class PI0Config(PreTrainedConfig):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
if self.derive_state_from_action:
self.use_relative_state = True
self.state_obs_steps = 2
# Validate configuration # Validate configuration
if self.n_action_steps > self.chunk_size: if self.n_action_steps > self.chunk_size:
raise ValueError( raise ValueError(
@@ -121,6 +147,13 @@ class PI0Config(PreTrainedConfig):
if self.dtype not in ["bfloat16", "float32"]: if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}") raise ValueError(f"Invalid dtype: {self.dtype}")
if self.use_relative_state and self.state_obs_steps < 2:
raise ValueError(
"use_relative_state requires state_obs_steps >= 2 "
f"(got {self.state_obs_steps}). Set state_obs_steps=2 for "
"UMI-style relative proprioception."
)
def validate_features(self) -> None: def validate_features(self) -> None:
"""Validate and set up input/output features.""" """Validate and set up input/output features."""
for i in range(self.empty_cameras): for i in range(self.empty_cameras):
@@ -166,8 +199,16 @@ class PI0Config(PreTrainedConfig):
def observation_delta_indices(self) -> None: def observation_delta_indices(self) -> None:
return None return None
@property
def state_delta_indices(self) -> list[int] | None:
if self.state_obs_steps >= 2:
return list(range(-(self.state_obs_steps - 1), 1))
return None
@property @property
def action_delta_indices(self) -> list: def action_delta_indices(self) -> list:
if self.derive_state_from_action:
return [-1] + list(range(self.chunk_size))
return list(range(self.chunk_size)) return list(range(self.chunk_size))
@property @property
+7 -3
View File
@@ -1230,8 +1230,11 @@ class PI0Policy(PreTrainedPolicy):
return images, img_masks return images, img_masks
def prepare_state(self, batch): def prepare_state(self, batch):
"""Pad state""" """Flatten multi-timestep state and pad to max_state_dim."""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) state = batch[OBS_STATE]
if state.ndim == 3:
state = state.flatten(start_dim=1)
state = pad_vector(state, self.config.max_state_dim)
return state return state
def prepare_action(self, batch): def prepare_action(self, batch):
@@ -1250,7 +1253,8 @@ class PI0Policy(PreTrainedPolicy):
# Action queue logic for n_action_steps > 1 # Action queue logic for n_action_steps > 1
if len(self._action_queue) == 0: if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] skip = self.config.latency_skip_steps
actions = self.predict_action_chunk(batch)[:, skip : skip + self.config.n_action_steps]
# Transpose to get shape (n_action_steps, batch_size, action_dim) # Transpose to get shape (n_action_steps, batch_size, action_dim)
self._action_queue.extend(actions.transpose(0, 1)) self._action_queue.extend(actions.transpose(0, 1))
+17 -1
View File
@@ -24,6 +24,7 @@ from lerobot.processor import (
AbsoluteActionsProcessorStep, AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep, AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep, ComplementaryDataProcessorStep,
DeriveStateFromActionStep,
DeviceProcessorStep, DeviceProcessorStep,
NormalizerProcessorStep, NormalizerProcessorStep,
PolicyAction, PolicyAction,
@@ -31,6 +32,7 @@ from lerobot.processor import (
ProcessorStep, ProcessorStep,
ProcessorStepRegistry, ProcessorStepRegistry,
RelativeActionsProcessorStep, RelativeActionsProcessorStep,
RelativeStateProcessorStep,
RenameObservationsProcessorStep, RenameObservationsProcessorStep,
TokenizerProcessorStep, TokenizerProcessorStep,
UnnormalizerProcessorStep, UnnormalizerProcessorStep,
@@ -128,13 +130,25 @@ def make_pi0_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines. A tuple containing the configured pre-processor and post-processor pipelines.
""" """
derive_state_step = DeriveStateFromActionStep(
enabled=getattr(config, "derive_state_from_action", False),
)
relative_step = RelativeActionsProcessorStep( relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions, enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []), exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None), action_names=getattr(config, "action_feature_names", None),
) )
# OpenPI order: raw → relative → normalize → model → unnormalize → absolute relative_state_step = RelativeStateProcessorStep(
enabled=getattr(config, "use_relative_state", False),
exclude_joints=getattr(config, "relative_exclude_state_joints", []),
state_names=getattr(config, "state_feature_names", None),
)
# Order: DeriveStateFromAction extracts state from the extended action chunk,
# then relative_action uses current state[t] for subtraction,
# then relative_state converts the multi-timestep state to offsets.
input_steps: list[ProcessorStep] = [ input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
@@ -146,7 +160,9 @@ def make_pi0_pre_post_processors(
padding="max_length", padding="max_length",
), ),
DeviceProcessorStep(device=config.device), DeviceProcessorStep(device=config.device),
derive_state_step,
relative_step, relative_step,
relative_state_step,
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
+6
View File
@@ -77,9 +77,12 @@ from .policy_robot_bridge import (
) )
from .relative_action_processor import ( from .relative_action_processor import (
AbsoluteActionsProcessorStep, AbsoluteActionsProcessorStep,
DeriveStateFromActionStep,
RelativeActionsProcessorStep, RelativeActionsProcessorStep,
RelativeStateProcessorStep,
to_absolute_actions, to_absolute_actions,
to_relative_actions, to_relative_actions,
to_relative_state,
) )
from .rename_processor import RenameObservationsProcessorStep from .rename_processor import RenameObservationsProcessorStep
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
@@ -107,7 +110,9 @@ __all__ = [
"make_default_robot_action_processor", "make_default_robot_action_processor",
"make_default_robot_observation_processor", "make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep", "AbsoluteActionsProcessorStep",
"DeriveStateFromActionStep",
"RelativeActionsProcessorStep", "RelativeActionsProcessorStep",
"RelativeStateProcessorStep",
"MapDeltaActionToRobotActionStep", "MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep", "MapTensorToDeltaActionDictStep",
"NormalizerProcessorStep", "NormalizerProcessorStep",
@@ -139,6 +144,7 @@ __all__ = [
"TruncatedProcessorStep", "TruncatedProcessorStep",
"to_absolute_actions", "to_absolute_actions",
"to_relative_actions", "to_relative_actions",
"to_relative_state",
"UnnormalizerProcessorStep", "UnnormalizerProcessorStep",
"VanillaObservationProcessorStep", "VanillaObservationProcessorStep",
] ]
@@ -30,10 +30,13 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry
__all__ = [ __all__ = [
"MapDeltaActionToRobotActionStep", "MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep", "MapTensorToDeltaActionDictStep",
"DeriveStateFromActionStep",
"RelativeActionsProcessorStep", "RelativeActionsProcessorStep",
"AbsoluteActionsProcessorStep", "AbsoluteActionsProcessorStep",
"RelativeStateProcessorStep",
"to_relative_actions", "to_relative_actions",
"to_absolute_actions", "to_absolute_actions",
"to_relative_state",
] ]
@@ -81,6 +84,41 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
return actions return actions
@ProcessorStepRegistry.register("derive_state_from_action_processor")
@dataclass
class DeriveStateFromActionStep(ProcessorStep):
"""Derives 2-step observation.state from the action chunk (UMI-style).
Expects action with one extra leading timestep: [B, chunk_size+1, D]
from action_delta_indices = [-1, 0, 1, ..., chunk_size-1].
Extracts [action[t-1], action[t]] as state and strips the extra timestep.
No-op during inference (state comes from robot).
"""
enabled: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
action = transition.get(TransitionKey.ACTION)
if action is None or action.ndim < 3:
return transition
new_transition = transition.copy()
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
new_obs[OBS_STATE] = action[..., :2, :]
new_transition[TransitionKey.ACTION] = action[..., 1:, :]
new_transition[TransitionKey.OBSERVATION] = new_obs
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("delta_actions_processor") @ProcessorStepRegistry.register("delta_actions_processor")
@dataclass @dataclass
class RelativeActionsProcessorStep(ProcessorStep): class RelativeActionsProcessorStep(ProcessorStep):
@@ -124,7 +162,14 @@ class RelativeActionsProcessorStep(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION, {}) observation = transition.get(TransitionKey.OBSERVATION, {})
state = observation.get(OBS_STATE) if observation else None raw_state = observation.get(OBS_STATE) if observation else None
# When state_delta_indices loads multi-timestep state [B, n_obs, D],
# use only the current (last) timestep for relative action conversion.
if raw_state is not None:
state = raw_state[..., -1, :] if raw_state.ndim >= 3 else raw_state
else:
state = None
# Always cache state for the paired AbsoluteActionsProcessorStep # Always cache state for the paired AbsoluteActionsProcessorStep
if state is not None: if state is not None:
@@ -155,6 +200,120 @@ class RelativeActionsProcessorStep(ProcessorStep):
return features return features
def to_relative_state(state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert multi-timestep absolute state to relative (offset from current timestep).
Each timestep becomes: ``state[..., t, :] - state[..., -1, :]`` for masked dims.
The last (current) timestep becomes zeros for masked dims.
Args:
state: (..., n_obs, state_dim) last timestep is the reference (current).
mask: Which dims to convert. Can be shorter than state_dim.
"""
mask_t = torch.tensor(mask, dtype=state.dtype, device=state.device)
dims = mask_t.shape[0]
current = state[..., -1:, :] # (..., 1, state_dim)
state = state.clone()
state[..., :dims] -= current[..., :dims] * mask_t
return state
@ProcessorStepRegistry.register("relative_state_processor")
@dataclass
class RelativeStateProcessorStep(ProcessorStep):
"""Converts observation.state to relative (offset from current timestep).
UMI-style relative proprioception: each state timestep is expressed as
an offset from the current EE pose, providing velocity information.
During training (multi-timestep input from ``state_delta_indices``):
``state[..., t, :] -= state[..., -1, :]`` subtract current from all.
During inference (single timestep): buffers the previous state and stacks
``[previous, current]`` before applying the relative conversion, producing
the same ``[n_obs, D]`` shape the model expects.
Attributes:
enabled: Whether to apply the relative conversion.
exclude_joints: Joint/dim names to keep absolute.
state_names: State dimension names from dataset metadata.
"""
enabled: bool = False
exclude_joints: list[str] = field(default_factory=list)
state_names: list[str] | None = None
_previous_state: torch.Tensor | None = field(default=None, init=False, repr=False)
def _build_mask(self, state_dim: int) -> list[bool]:
if not self.exclude_joints or self.state_names is None:
return [True] * state_dim
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
if not exclude_tokens:
return [True] * state_dim
mask = []
for name in self.state_names[:state_dim]:
state_name = str(name).lower()
is_excluded = any(token == state_name or token in state_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < state_dim:
mask.extend([True] * (state_dim - len(mask)))
return mask
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
observation = transition.get(TransitionKey.OBSERVATION, {})
state = observation.get(OBS_STATE) if observation else None
if state is None:
return transition
new_transition = transition.copy()
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
mask = self._build_mask(state.shape[-1])
if state.ndim >= 3:
# [B, n_obs, D] — multi-timestep (training with state_delta_indices)
relative = to_relative_state(state, mask)
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, n_obs*D]
elif state.ndim == 2:
# [B, D] — single timestep (inference): buffer previous and stack
current = state
if self._previous_state is None:
self._previous_state = current.clone()
prev = self._previous_state
if prev.device != current.device or prev.dtype != current.dtype:
prev = prev.to(device=current.device, dtype=current.dtype)
stacked = torch.stack([prev, current], dim=-2) # [B, 2, D]
relative = to_relative_state(stacked, mask)
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, 2*D]
self._previous_state = current.clone()
new_transition[TransitionKey.OBSERVATION] = new_obs
return new_transition
def reset(self) -> None:
"""Reset the state buffer. Call at episode boundaries during inference."""
self._previous_state = None
def get_config(self) -> dict[str, Any]:
return {
"enabled": self.enabled,
"exclude_joints": self.exclude_joints,
"state_names": self.state_names,
}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("absolute_actions_processor") @ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass @dataclass
class AbsoluteActionsProcessorStep(ProcessorStep): class AbsoluteActionsProcessorStep(ProcessorStep):
@@ -254,6 +254,10 @@ class RecomputeStatsConfig(OperationConfig):
relative_exclude_joints: list[str] | None = None relative_exclude_joints: list[str] | None = None
chunk_size: int = 50 chunk_size: int = 50
num_workers: int = 0 num_workers: int = 0
relative_state: bool = False
relative_exclude_state_joints: list[str] | None = None
state_obs_steps: int = 2
derive_state_from_action: bool = False
@OperationConfig.register_subclass("info") @OperationConfig.register_subclass("info")
@@ -563,6 +567,14 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, " f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, "
f"exclude_joints={cfg.operation.relative_exclude_joints})" f"exclude_joints={cfg.operation.relative_exclude_joints})"
) )
if cfg.operation.relative_state:
logging.info(
f"Relative state stats enabled (state_obs_steps={cfg.operation.state_obs_steps}, "
f"exclude_state_joints={cfg.operation.relative_exclude_state_joints})"
)
if cfg.operation.derive_state_from_action:
logging.info("Derive state from action enabled (implies relative_state=True, state_obs_steps=2)")
recompute_stats( recompute_stats(
dataset, dataset,
@@ -571,6 +583,10 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
relative_exclude_joints=cfg.operation.relative_exclude_joints, relative_exclude_joints=cfg.operation.relative_exclude_joints,
chunk_size=cfg.operation.chunk_size, chunk_size=cfg.operation.chunk_size,
num_workers=cfg.operation.num_workers, num_workers=cfg.operation.num_workers,
relative_state=cfg.operation.relative_state,
relative_exclude_state_joints=cfg.operation.relative_exclude_state_joints,
state_obs_steps=cfg.operation.state_obs_steps,
derive_state_from_action=cfg.operation.derive_state_from_action,
) )
logging.info(f"Stats written to {dataset.root}") logging.info(f"Stats written to {dataset.root}")