refactor to use relative state

Pepijn
2026-04-01 17:23:58 +02:00
parent 0fc855df13
commit 58bd11caf3
14 changed files with 502 additions and 296 deletions
+16 -1
@@ -202,11 +202,22 @@ Here is how the different processors compose. Each arrow is a processor step, an
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
State Derivation │ Action column ────→ State + Action │
│ DeriveStateFromActionStep (pre only) │
│ (UMI-style: state from action chunk) │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
Action Repr. │ Absolute ←────→ Relative │
│ RelativeActionsProcessorStep (pre) │
│ AbsoluteActionsProcessorStep (post) │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
State Repr. │ Absolute ────→ Relative │
│ RelativeStateProcessorStep (pre only) │
└─────────────────────────────────────────┘
┌─────────────────────────────────────────┐
Normalization │ Raw ←────→ Normalized │
│ NormalizerProcessorStep (pre) │
@@ -216,6 +227,10 @@ Here is how the different processors compose. Each arrow is a processor step, an
A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`.
With UMI-style relative proprioception (`use_relative_state=True`), the preprocessor also converts observation.state to offsets from the current timestep via `RelativeStateProcessorStep` before normalization. This is a pre-processing-only step (state is an input, not an output).
With `derive_state_from_action=True`, the preprocessor first runs `DeriveStateFromActionStep` to extract a 2-step state from the extended action chunk. This enables full UMI-style training without a separate `observation.state` column. See the [UMI pi0 guide](umi_pi0_relative_ee) for details.
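As a quick sanity check on the composition above, here is a minimal NumPy round trip (a toy stand-in for the processor steps, not their real API): convert to relative before normalization on the way in, and invert the conversion after unnormalization on the way out.
```python
import numpy as np

state = np.array([0.1, 0.2, 0.3])                # current EE pose at time t
chunk = state + 0.01 * np.arange(1, 5)[:, None]  # absolute action chunk [4, 3]

relative = chunk - state     # RelativeActionsProcessorStep (pre): offsets from state[t]
restored = relative + state  # AbsoluteActionsProcessorStep (post): back to absolute
assert np.allclose(restored, chunk)
```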
## References
- [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions.
+75 -61
@@ -4,16 +4,13 @@ This guide explains how to prepare a UMI-collected dataset for training a pi0 po
**What we will do:**
1. Recompute dataset statistics for relative actions and state.
2. Train pi0 with `derive_state_from_action=true` (full UMI pipeline).
3. Evaluate the trained policy on a real robot.
## Background
[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. UMI datasets stored in LeRobot format already contain `action` (absolute EE pose) and wrist-camera images. To train pi0 with relative actions, we need **relative action statistics** — so the normalizer sees `(action - state)` distributions.
### Why relative actions?
@@ -25,72 +22,39 @@ relative_action[i] = absolute_action[t + i] - state[t]
This is the representation advocated by UMI (Chi et al., 2024). Compared to absolute actions it removes the need for a consistent global coordinate frame, and compared to delta actions (each step relative to the previous) it avoids error accumulation across the chunk. See the [Action Representations](action_representations) guide for a full comparison.
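To make the comparison concrete, here is a toy 1-D sketch of the two chunk encodings (illustrative values, plain NumPy):
```python
import numpy as np

state = 1.0                                # pose at decision time t
absolute = np.array([1.1, 1.2, 1.3, 1.4])  # absolute_action[t + i]

relative = absolute - state                # every step offset from state[t]
delta = np.diff(absolute, prepend=state)   # each step offset from the previous

# Recovering absolute targets: a single add for relative, a cumulative sum
# for delta, so one noisy delta step shifts every later pose in the chunk.
assert np.allclose(state + relative, absolute)
assert np.allclose(state + np.cumsum(delta), absolute)
```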
### Full UMI mode: `derive_state_from_action`
When `derive_state_from_action=true`, pi0 automatically derives `observation.state` from the action column on the fly — no separate state column or dataset conversion step needed. Under the hood:
- `action_delta_indices` extends to `[-1, 0, 1, ..., 49]` (one extra leading timestep).
- `DeriveStateFromActionStep` extracts `[action[t-1], action[t]]` as a 2-step state and strips the extra timestep from the action chunk.
- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
- `RelativeStateProcessorStep` converts the 2-step state to relative proprioception (velocity + zeros) and flattens.
This single flag implies `use_relative_state=true` and `state_obs_steps=2`.
During **inference**, state comes from the robot (via FK), so `DeriveStateFromActionStep` is a no-op. `RelativeStateProcessorStep` buffers the previous state and applies the same conversion automatically.
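A tensor-shape sketch of the training path (illustrative sizes: batch 2, state dim 7, chunk 4 instead of 50; the slicing mirrors the steps listed above):
```python
import torch

B, D, chunk_size = 2, 7, 4
action = torch.randn(B, chunk_size + 1, D)  # extended chunk from indices [-1, 0, ..., 3]

state = action[:, :2, :]   # DeriveStateFromActionStep: [action[t-1], action[t]] → [B, 2, D]
action = action[:, 1:, :]  # ...and the action chunk is back to [B, chunk_size, D]

action = action - state[:, -1:, :]             # RelativeActionsProcessorStep: offsets from state[t]
state = (state - state[:, -1:, :]).flatten(1)  # RelativeStateProcessorStep: [B, 2*D],
                                               # i.e. [prev - current, zeros] per sample
```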
## Step 1: Recompute Stats
Use the built-in CLI to recompute dataset statistics for relative actions and derive-state-from-action:
```bash
lerobot-edit-dataset \
--repo_id <your_dataset> \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.derive_state_from_action true \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']" \
--push_to_hub true
```
The `derive_state_from_action` flag tells `recompute_stats` to read from the action column (instead of `observation.state`) when computing relative state stats. It automatically enables `relative_state=true` and `state_obs_steps=2`.
The `relative_exclude_joints` parameter specifies joints that stay absolute. Gripper commands are typically binary or continuous open/close and don't benefit from relative encoding. Leave it as `"[]"` to convert all dimensions to relative.
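The matching is name-based; a rough sketch of how an exclude list becomes a per-dimension mask (illustrative, mirroring the substring matching the processor steps use):
```python
names = ["x", "y", "z", "wx", "wy", "wz", "gripper"]
exclude = ["gripper"]
mask = [not any(tok == n.lower() or tok in n.lower() for tok in exclude) for n in names]
print(mask)  # [True, True, True, True, True, True, False] -> gripper stays absolute
```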
## Step 2: Train
No custom training script is needed — standard `lerobot-train` handles everything:
@@ -99,19 +63,26 @@ lerobot-train \
--dataset.repo_id=<hf_username>/<dataset_repo_id> \
--policy.type=pi0 \
--policy.pretrained_path=lerobot/pi0_base \
--policy.derive_state_from_action=true \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]'
```
`derive_state_from_action=true` auto-enables `use_relative_state=true` and `state_obs_steps=2`.
Under the hood, the training pipeline:
- Loads relative action stats and relative state stats from the dataset's `meta/stats.json`.
- Extends `action_delta_indices` to `[-1, 0, 1, ..., 49]` to load one extra leading timestep.
- `DeriveStateFromActionStep` extracts the 2-step state from the action chunk and strips the extra timestep.
- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
- `RelativeStateProcessorStep` converts the 2-step state to relative offsets from the current timestep, then flattens.
- `NormalizerProcessorStep` normalizes everything.
- The model trains on normalized relative values.
See the [pi0 documentation](pi0) for all available training options.
## Step 3: Evaluate
The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real robot (SO-100 with EE space):
@@ -121,10 +92,20 @@ python examples/umi_pi0_relative_ee/evaluate.py
Edit `HF_MODEL_ID`, `HF_DATASET_ID`, and robot configuration at the top of the file.
### Latency compensation
For real robot deployment, you may want to skip the first few steps of each predicted action chunk to compensate for system latency. Set `LATENCY_SKIP_STEPS` in the evaluate script:
```python
LATENCY_SKIP_STEPS = 0 # ceil(total_latency_ms / (1000 / FPS))
```
For example, at 10Hz with ~200ms total latency, set `LATENCY_SKIP_STEPS = 2`.
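The arithmetic, for a couple of control rates (plain Python):
```python
import math

def latency_skip_steps(total_latency_ms: float, fps: float) -> int:
    return math.ceil(total_latency_ms / (1000 / fps))

print(latency_skip_steps(200, 10))  # 2 (the example above)
print(latency_skip_steps(120, 30))  # 4 (each step covers ~33 ms)
```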
The inference flow uses pi0's built-in processor pipeline — no custom wrappers needed:
1. **Robot → FK** — Joint positions are converted to EE pose via `ForwardKinematicsJointsToEE`, producing `observation.state`.
2. **Preprocessor** — `DeriveStateFromActionStep` is a no-op (state comes from robot). `RelativeStateProcessorStep` buffers previous state, stacks, and converts to relative. `RelativeActionsProcessorStep` caches state. `NormalizerProcessorStep` normalizes.
3. **pi0 inference** — The model predicts a normalized relative action chunk.
4. **Postprocessor** — `UnnormalizerProcessorStep` unnormalizes, then `AbsoluteActionsProcessorStep` adds the cached state back to get absolute EE targets.
5. **IK → Robot** — `InverseKinematicsEEToJoints` converts absolute EE targets to joint commands.
@@ -132,14 +113,22 @@ The inference flow uses pi0's built-in processor pipeline — no custom wrappers
## How the Pieces Fit Together
```
Training (full UMI mode: derive_state_from_action=true):
DataLoader (action: B,51,D)
→ DeriveStateFromActionStep (state = action[:,:2,:], action = action[:,1:,:])
→ RelativeActionsProcessorStep (action -= state[:,-1,:])
→ RelativeStateProcessorStep (state offsets from current, flatten → B,2*D)
→ NormalizerProcessorStep → pi0 model
Inference:
robot joints → FK → observation.state (absolute EE)
DeriveStateFromActionStep (no-op)
RelativeActionsProcessorStep (caches state)
RelativeStateProcessorStep (buffers prev, stacks, subtracts, flattens)
NormalizerProcessorStep → pi0 model → relative action chunk
UnnormalizerProcessorStep
@@ -149,6 +138,31 @@ Inference:
IK → joint targets → robot
```
## Manual mode (without derive_state_from_action)
If your dataset already has `observation.state` (or you want to add it separately), you can skip `derive_state_from_action` and use relative actions + relative state independently:
```bash
# Recompute stats
lerobot-edit-dataset \
--repo_id <your_dataset> \
--operation.type recompute_stats \
--operation.relative_action true \
--operation.relative_state true \
--operation.state_obs_steps 2 \
--operation.chunk_size 50 \
--operation.relative_exclude_joints "['gripper']"
# Train
lerobot-train \
--dataset.repo_id=<your_dataset> \
--policy.type=pi0 \
--policy.use_relative_actions=true \
--policy.use_relative_state=true \
--policy.state_obs_steps=2 \
--policy.relative_exclude_joints='["gripper"]'
```
## References
- [UMI: Universal Manipulation Interface](https://umi-gripper.github.io) — Chi et al., 2024. Defines relative trajectory actions.
@@ -1,220 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Add ``observation.state`` to an existing LeRobot dataset.
pi0 uses ``observation.state`` as its proprioceptive input AND for
relative action conversion (action - state). This script creates
``observation.state`` by concatenating one or more existing features.
Ordering matters: the features whose dimensions correspond to ``action``
must come FIRST, because ``RelativeActionsProcessorStep`` subtracts
``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE
pose) are appended after and are seen by the model but not used for
relative conversion.
Example: action = [proximal, distal], and we want the model to also see
the EE pose:
STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"]
→ observation.state = [proximal, distal, x, y, z, ax, ay, az]
The relative conversion uses state[:2] = [proximal, distal] to subtract
from action[:2], and the model sees all 8 dimensions.
After running this script, recompute relative action stats:
lerobot-edit-dataset \\
--repo_id <your_dataset> \\
--operation.type recompute_stats \\
--operation.relative_action true \\
--operation.chunk_size 50 \\
--operation.relative_exclude_joints "[]" \\
--push_to_hub true
Usage:
python convert_umi_dataset.py
"""
from __future__ import annotations
import logging
from collections.abc import Callable
import numpy as np
from lerobot.datasets.dataset_tools import add_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
HF_DATASET_ID = ""
# Output repo ID. Set to None for default "<input>_modified".
OUTPUT_REPO_ID: str | None = None
# Features to concatenate into observation.state. Order matters:
# action-matching features FIRST, then extra proprioception.
# Set to a single string to copy one feature directly.
STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"]
# Only used when STATE_SOURCE_FEATURES is None:
# derive state from action with a per-episode offset.
STATE_ACTION_OFFSET = 1
# Push the augmented dataset to the Hugging Face Hub.
PUSH_TO_HUB = True
def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]:
hf = dataset.hf_dataset
ep = np.array(hf["episode_index"])
fr = np.array(hf["frame_index"])
return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))}
def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable:
"""Concatenate multiple features into observation.state."""
hf = dataset.hf_dataset
key_to_global = _build_global_index(dataset)
columns = [hf[feat] for feat in source_features]
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
g = key_to_global[(ep_idx, frame_idx)]
parts = []
for col in columns:
val = col[g]
if hasattr(val, "tolist"):
flat = val.tolist()
if isinstance(flat, list):
parts.extend(flat)
else:
parts.append(flat)
elif isinstance(val, list):
parts.extend(val)
else:
parts.append(float(val))
return parts
return _get_state
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
"""Derive state from action with a per-episode offset."""
hf = dataset.hf_dataset
episode_indices = np.array(hf["episode_index"])
frame_indices = np.array(hf["frame_index"])
ep_sorted: dict[int, list[tuple[int, int]]] = {}
for ep_idx in np.unique(episode_indices):
ep_mask = episode_indices == ep_idx
ep_globals = np.where(ep_mask)[0]
ep_frames = frame_indices[ep_globals]
order = np.argsort(ep_frames)
ep_sorted[int(ep_idx)] = [(int(ep_frames[o]), int(ep_globals[o])) for o in order]
ep_frame_to_local: dict[int, dict[int, int]] = {}
for ep_idx, sorted_list in ep_sorted.items():
ep_frame_to_local[ep_idx] = {frame: local for local, (frame, _) in enumerate(sorted_list)}
actions = hf["action"]
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
local_t = ep_frame_to_local[ep_idx][frame_idx]
source_local = max(0, local_t - offset)
_, source_global = ep_sorted[ep_idx][source_local]
return actions[source_global]
return _get_state
def main():
logger.info(f"Loading dataset {HF_DATASET_ID}")
dataset = LeRobotDataset(HF_DATASET_ID)
if "observation.state" in dataset.features:
logger.info("observation.state already exists — nothing to do")
return
action_meta = dataset.features["action"]
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
if STATE_SOURCE_FEATURES is not None:
source_list = (
[STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES)
)
for feat in source_list:
if feat not in dataset.features:
raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}")
# Compute combined shape and names
total_dim = 0
all_names = []
for feat in source_list:
meta = dataset.features[feat]
total_dim += meta["shape"][0]
names = meta.get("names")
if names:
all_names.extend(names)
logger.info(
f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})"
)
state_fn = _build_state_from_features(dataset, source_list)
state_feature_info = {
"dtype": "float32",
"shape": [total_dim],
"names": all_names or None,
}
else:
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
state_fn = _build_state_from_action_offset(dataset, offset=STATE_ACTION_OFFSET)
state_feature_info = {
"dtype": "float32",
"shape": list(action_meta["shape"]),
"names": action_meta.get("names"),
}
augmented = add_features(
dataset,
features={"observation.state": (state_fn, state_feature_info)},
repo_id=OUTPUT_REPO_ID,
)
logger.info("observation.state added")
if PUSH_TO_HUB:
logger.info(f"Pushing to Hub: {augmented.repo_id}")
augmented.push_to_hub()
logger.info(
f"Done. Dataset at: {augmented.root}\n"
"Now recompute relative action stats:\n"
" lerobot-edit-dataset \\\n"
f" --repo_id {augmented.repo_id} \\\n"
" --operation.type recompute_stats \\\n"
" --operation.relative_action true \\\n"
" --operation.chunk_size 50 \\\n"
' --operation.relative_exclude_joints "[]" \\\n'
" --push_to_hub true"
)
if __name__ == "__main__":
main()
+23 -7
@@ -17,17 +17,20 @@
"""
Inference script for a pi0 model trained with **relative EE actions**.
This uses the built-in ``DeriveStateFromActionStep`` (no-op at inference),
``RelativeActionsProcessorStep``, ``AbsoluteActionsProcessorStep``, and
``RelativeStateProcessorStep`` that are already wired into pi0's processor
pipeline.
The inference loop:
1. Reads joint positions from the robot.
2. Converts to EE pose via forward kinematics (FK).
This produces ``observation.state`` with the current EE pose.
3. The pi0 preprocessor:
a) ``DeriveStateFromActionStep`` — no-op (state comes from robot).
b) ``RelativeActionsProcessorStep`` caches the raw state.
c) ``RelativeStateProcessorStep`` buffers prev state, stacks, subtracts.
d) ``NormalizerProcessorStep`` normalizes state and actions.
4. pi0 predicts relative action chunk.
5. The pi0 postprocessor:
a) ``UnnormalizerProcessorStep`` unnormalizes.
@@ -51,6 +54,7 @@ from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.processor import (
RelativeStateProcessorStep,
RobotProcessorPipeline,
make_default_teleop_action_processor,
)
@@ -79,6 +83,11 @@ TASK_DESCRIPTION = "manipulation task"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
# Latency compensation: skip this many steps from the start of each predicted
# action chunk. Formula: ceil(total_latency_ms / (1000 / FPS)).
# E.g. at 10Hz with ~200ms total system latency: ceil(200 / 100) = 2.
LATENCY_SKIP_STEPS = 0
# EE feature keys produced by ForwardKinematicsJointsToEE
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
@@ -94,6 +103,7 @@ def main():
robot = SO100Follower(robot_config)
policy = PI0Policy.from_pretrained(HF_MODEL_ID)
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
@@ -151,9 +161,8 @@ def main():
# Build pre/post processors from the trained model.
# The pi0 processor pipeline already includes:
# pre: ... → RelativeStateProcessorStep → RelativeActionsProcessorStep → NormalizerProcessorStep
# post: UnnormalizerProcessorStep → AbsoluteActionsProcessorStep → ...
# These handle the relative ↔ absolute conversion automatically.
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
@@ -161,6 +170,9 @@ def main():
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
# Find the relative state step (if present) so we can reset its buffer between episodes.
_relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
robot.connect()
listener, events = init_keyboard_listener()
@@ -174,6 +186,10 @@ def main():
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Reset relative state buffer so velocity is zero at episode start
for step in _relative_state_steps:
step.reset()
record_loop(
robot=robot,
events=events,
+11
@@ -115,6 +115,17 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@property
def state_delta_indices(self) -> list | None: # type: ignore[type-arg]
"""Delta indices specifically for observation.state.
When not None, overrides ``observation_delta_indices`` for the
``observation.state`` key only. Useful for loading state history
(e.g. ``[-1, 0]`` for UMI-style relative proprioception) without
also loading multiple image timesteps.
"""
return None
@abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError
+91
@@ -767,3 +767,94 @@ def compute_relative_action_stats(
)
return stats
def compute_relative_state_stats(
hf_dataset,
features: dict,
state_obs_steps: int = 2,
exclude_joints: list[str] | None = None,
source_key: str = OBS_STATE,
) -> dict[str, np.ndarray]:
"""Compute normalization statistics for observation.state after relative conversion.
For UMI-style relative proprioception with ``state_obs_steps`` timesteps,
each state observation becomes a stack of offsets from the current timestep:
``state[t-k] - state[t]`` for k in ``range(state_obs_steps-1, -1, -1)``.
The stats are computed over the flattened ``[state_obs_steps * state_dim]``
vector that the model actually sees after ``prepare_state`` flattening.
Args:
hf_dataset: The HuggingFace dataset with the source column and
"episode_index" columns.
features: Dataset feature metadata.
state_obs_steps: Number of observation timesteps (must be >= 2).
exclude_joints: State dimension names to keep absolute.
source_key: Column to read data from. Defaults to "observation.state".
When ``derive_state_from_action=True``, pass ``ACTION`` to read
from the action column instead.
Returns:
Statistics dict with keys "mean", "std", "min", "max", "q01", "q99".
"""
from lerobot.processor.relative_action_processor import RelativeStateProcessorStep
if exclude_joints is None:
exclude_joints = []
state_dim = features[source_key]["shape"][0]
state_names = features.get(source_key, {}).get("names")
mask_step = RelativeStateProcessorStep(
enabled=True,
exclude_joints=exclude_joints,
state_names=state_names,
)
relative_mask = np.array(mask_step._build_mask(state_dim), dtype=np.float32)
logging.info(f"Loading data from '{source_key}' for relative state stats...")
all_states = np.array(hf_dataset[source_key], dtype=np.float32)
episode_indices = np.array(hf_dataset["episode_index"])
# Build all valid windows of length state_obs_steps within each episode
n = len(all_states)
if n < state_obs_steps:
raise ValueError(f"Dataset has {n} frames but state_obs_steps={state_obs_steps}")
max_start = n - state_obs_steps
starts = np.arange(max_start + 1)
valid = episode_indices[starts] == episode_indices[starts + state_obs_steps - 1]
valid_starts = starts[valid]
if len(valid_starts) == 0:
raise RuntimeError("No valid state windows found within single episodes")
offsets = np.arange(state_obs_steps)
mask_dim = len(relative_mask)
running_stats = RunningQuantileStats()
batch_size = 50_000
for i in range(0, len(valid_starts), batch_size):
batch_starts = valid_starts[i : i + batch_size]
frame_idx = batch_starts[:, None] + offsets[None, :] # [N, state_obs_steps]
windows = all_states[frame_idx].copy() # [N, state_obs_steps, state_dim]
# Subtract current (last) timestep from all timesteps for masked dims
current = windows[:, -1:, :] # [N, 1, state_dim]
windows[:, :, :mask_dim] -= current[:, :, :mask_dim] * relative_mask[None, None, :]
# Flatten to [N, state_obs_steps * state_dim] (same as prepare_state)
flattened = windows.reshape(len(batch_starts), -1)
running_stats.update(flattened)
stats = running_stats.get_statistics()
excluded_dims = int(mask_dim - relative_mask.sum())
logging.info(
f"Relative state stats ({len(valid_starts)} windows, obs_steps={state_obs_steps}): "
f"relative_dims={int(relative_mask.sum())}/{mask_dim} (excluded={excluded_dims}), "
f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}"
)
return stats
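# Illustrative walk-through of the window-validity logic above: a window is
# kept only when its first and last frames share an episode_index.
#   episode_indices = [0, 0, 0, 1, 1], state_obs_steps = 2
#   starts = [0, 1, 2, 3] -> valid = [True, True, False, True]
#   valid_starts = [0, 1, 3]  (the window starting at 2 straddles episodes)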
+34
@@ -41,6 +41,7 @@ from lerobot.datasets.compute_stats import (
aggregate_stats,
compute_episode_stats,
compute_relative_action_stats,
compute_relative_state_stats,
)
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.io_utils import (
@@ -1544,6 +1545,10 @@ def recompute_stats(
relative_exclude_joints: list[str] | None = None,
chunk_size: int = 50,
num_workers: int = 0,
relative_state: bool = False,
relative_exclude_state_joints: list[str] | None = None,
state_obs_steps: int = 2,
derive_state_from_action: bool = False,
) -> LeRobotDataset:
"""Recompute stats.json from scratch by iterating all episodes.
@@ -1561,10 +1566,22 @@ def recompute_stats(
``policy.chunk_size``. Only used when ``relative_action=True``.
num_workers: Number of parallel threads for relative action stats computation.
Values <= 1 mean single-threaded. Only used when ``relative_action=True``.
relative_state: If True, compute observation.state stats in relative space
(multi-timestep offsets from current). This matches the normalization
the model sees during training with ``use_relative_state=True``.
relative_exclude_state_joints: State dim names to exclude from relative conversion.
state_obs_steps: Number of observation timesteps for relative state stats.
Should match ``policy.state_obs_steps``. Only used when ``relative_state=True``.
derive_state_from_action: If True, compute relative state stats from the
action column instead of observation.state. Implies ``relative_state=True``
and ``state_obs_steps=2``.
Returns:
The same dataset with updated stats.
"""
if derive_state_from_action:
relative_state = True
state_obs_steps = 2
features = dataset.meta.features
meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"}
numeric_features = {
@@ -1596,6 +1613,20 @@ def recompute_stats(
)
features_to_compute.pop(ACTION, None)
# When relative_state is enabled, compute state stats over the flattened
# multi-timestep relative representation (matching what the model sees).
relative_state_stats = None
if relative_state and (OBS_STATE in features or derive_state_from_action):
source_key = ACTION if derive_state_from_action else OBS_STATE
relative_state_stats = compute_relative_state_stats(
hf_dataset=dataset.hf_dataset,
features=features,
state_obs_steps=state_obs_steps,
exclude_joints=relative_exclude_state_joints,
source_key=source_key,
)
features_to_compute.pop(OBS_STATE, None)
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
data_dir = dataset.root / DATA_DIR
@@ -1632,6 +1663,9 @@ def recompute_stats(
if relative_action_stats is not None:
new_stats[ACTION] = relative_action_stats
if relative_state_stats is not None:
new_stats[OBS_STATE] = relative_state_stats
# Merge: keep existing stats for features we didn't recompute
if dataset.meta.stats:
for key, value in dataset.meta.stats.items():
+5 -2
@@ -25,7 +25,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, OBS_STATE, REWARD
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -52,12 +52,15 @@ def resolve_delta_timestamps(
returns `None` if the resulting dict is empty.
"""
delta_timestamps = {}
state_delta = getattr(cfg, "state_delta_indices", None)
for key in ds_meta.features:
if key == REWARD and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key == OBS_STATE and state_delta is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in state_delta]
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0:
@@ -57,6 +57,28 @@ class PI0Config(PreTrainedConfig):
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Relative state (UMI-style relative proprioception): converts multi-timestep
# observation.state to offsets from the current timestep, providing velocity info.
# Requires state_obs_steps >= 2. The flattened multi-timestep state is padded to
# max_state_dim, so ensure state_obs_steps * state_dim <= max_state_dim.
use_relative_state: bool = False
state_obs_steps: int = 1
relative_exclude_state_joints: list[str] = field(default_factory=list)
# Populated at runtime from dataset metadata by make_policy.
state_feature_names: list[str] | None = None
# Derive observation.state from the action column (UMI-style).
# When True, action_delta_indices loads one extra leading timestep [-1, 0, ..., chunk_size-1],
# DeriveStateFromActionStep extracts [action[t-1], action[t]] as a 2-step state,
# and strips the extra timestep from the action chunk.
# Implies use_relative_state=True and state_obs_steps=2.
derive_state_from_action: bool = False
# Latency compensation: skip this many steps from the start of each predicted
# action chunk during inference. E.g. at 10Hz with ~200ms total latency,
# latency_skip_steps=2 compensates for the delay.
latency_skip_steps: int = 0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
@@ -106,6 +128,10 @@ class PI0Config(PreTrainedConfig):
def __post_init__(self):
super().__post_init__()
if self.derive_state_from_action:
self.use_relative_state = True
self.state_obs_steps = 2
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
@@ -121,6 +147,13 @@ class PI0Config(PreTrainedConfig):
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
if self.use_relative_state and self.state_obs_steps < 2:
raise ValueError(
"use_relative_state requires state_obs_steps >= 2 "
f"(got {self.state_obs_steps}). Set state_obs_steps=2 for "
"UMI-style relative proprioception."
)
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
@@ -166,8 +199,16 @@ class PI0Config(PreTrainedConfig):
def observation_delta_indices(self) -> None:
return None
@property
def state_delta_indices(self) -> list[int] | None:
if self.state_obs_steps >= 2:
return list(range(-(self.state_obs_steps - 1), 1))
return None
@property
def action_delta_indices(self) -> list:
if self.derive_state_from_action:
return [-1] + list(range(self.chunk_size))
return list(range(self.chunk_size))
@property
+7 -3
@@ -1230,8 +1230,11 @@ class PI0Policy(PreTrainedPolicy):
return images, img_masks
def prepare_state(self, batch):
"""Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
"""Flatten multi-timestep state and pad to max_state_dim."""
state = batch[OBS_STATE]
if state.ndim == 3:
state = state.flatten(start_dim=1)
state = pad_vector(state, self.config.max_state_dim)
return state
def prepare_action(self, batch):
@@ -1250,7 +1253,8 @@ class PI0Policy(PreTrainedPolicy):
# Action queue logic for n_action_steps > 1
if len(self._action_queue) == 0:
skip = self.config.latency_skip_steps
actions = self.predict_action_chunk(batch)[:, skip : skip + self.config.n_action_steps]
# Transpose to get shape (n_action_steps, batch_size, action_dim)
self._action_queue.extend(actions.transpose(0, 1))
+17 -1
@@ -24,6 +24,7 @@ from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeriveStateFromActionStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -31,6 +32,7 @@ from lerobot.processor import (
ProcessorStep,
ProcessorStepRegistry,
RelativeActionsProcessorStep,
RelativeStateProcessorStep,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
@@ -128,13 +130,25 @@ def make_pi0_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
derive_state_step = DeriveStateFromActionStep(
enabled=getattr(config, "derive_state_from_action", False),
)
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
relative_state_step = RelativeStateProcessorStep(
enabled=getattr(config, "use_relative_state", False),
exclude_joints=getattr(config, "relative_exclude_state_joints", []),
state_names=getattr(config, "state_feature_names", None),
)
# Order: DeriveStateFromAction extracts state from the extended action chunk,
# then relative_action uses current state[t] for subtraction,
# then relative_state converts the multi-timestep state to offsets.
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
@@ -146,7 +160,9 @@ def make_pi0_pre_post_processors(
padding="max_length",
),
DeviceProcessorStep(device=config.device),
derive_state_step,
relative_step,
relative_state_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
+6
@@ -77,9 +77,12 @@ from .policy_robot_bridge import (
)
from .relative_action_processor import (
AbsoluteActionsProcessorStep,
DeriveStateFromActionStep,
RelativeActionsProcessorStep,
RelativeStateProcessorStep,
to_absolute_actions,
to_relative_actions,
to_relative_state,
)
from .rename_processor import RenameObservationsProcessorStep
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
@@ -107,7 +110,9 @@ __all__ = [
"make_default_robot_action_processor",
"make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"DeriveStateFromActionStep",
"RelativeActionsProcessorStep",
"RelativeStateProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NormalizerProcessorStep",
@@ -139,6 +144,7 @@ __all__ = [
"TruncatedProcessorStep",
"to_absolute_actions",
"to_relative_actions",
"to_relative_state",
"UnnormalizerProcessorStep",
"VanillaObservationProcessorStep",
]
@@ -30,10 +30,13 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry
__all__ = [
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"DeriveStateFromActionStep",
"RelativeActionsProcessorStep",
"AbsoluteActionsProcessorStep",
"RelativeStateProcessorStep",
"to_relative_actions",
"to_absolute_actions",
"to_relative_state",
]
@@ -81,6 +84,41 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
return actions
@ProcessorStepRegistry.register("derive_state_from_action_processor")
@dataclass
class DeriveStateFromActionStep(ProcessorStep):
"""Derives 2-step observation.state from the action chunk (UMI-style).
Expects action with one extra leading timestep: [B, chunk_size+1, D]
from action_delta_indices = [-1, 0, 1, ..., chunk_size-1].
Extracts [action[t-1], action[t]] as state and strips the extra timestep.
No-op during inference (state comes from robot).
"""
enabled: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
action = transition.get(TransitionKey.ACTION)
if action is None or action.ndim < 3:
return transition
new_transition = transition.copy()
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
new_obs[OBS_STATE] = action[..., :2, :]
new_transition[TransitionKey.ACTION] = action[..., 1:, :]
new_transition[TransitionKey.OBSERVATION] = new_obs
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class RelativeActionsProcessorStep(ProcessorStep):
@@ -124,7 +162,14 @@ class RelativeActionsProcessorStep(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION, {})
raw_state = observation.get(OBS_STATE) if observation else None
# When state_delta_indices loads multi-timestep state [B, n_obs, D],
# use only the current (last) timestep for relative action conversion.
if raw_state is not None:
state = raw_state[..., -1, :] if raw_state.ndim >= 3 else raw_state
else:
state = None
# Always cache state for the paired AbsoluteActionsProcessorStep
if state is not None:
@@ -155,6 +200,120 @@ class RelativeActionsProcessorStep(ProcessorStep):
return features
def to_relative_state(state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert multi-timestep absolute state to relative (offset from current timestep).
Each timestep becomes: ``state[..., t, :] - state[..., -1, :]`` for masked dims.
The last (current) timestep becomes zeros for masked dims.
Args:
state: (..., n_obs, state_dim) last timestep is the reference (current).
mask: Which dims to convert. Can be shorter than state_dim.
"""
mask_t = torch.tensor(mask, dtype=state.dtype, device=state.device)
dims = mask_t.shape[0]
current = state[..., -1:, :] # (..., 1, state_dim)
state = state.clone()
state[..., :dims] -= current[..., :dims] * mask_t
return state
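# Illustrative example: with mask = [True, True, False] and a [1, 2, 3] stack
# of [previous, current] = [[0.9, 1.9, 0.5], [1.0, 2.0, 0.5]], the result is
# [[-0.1, -0.1, 0.5], [0.0, 0.0, 0.5]]: the previous row becomes an offset
# from current, the current row zeros for masked dims, and dim 2 (masked
# False) is left absolute.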
@ProcessorStepRegistry.register("relative_state_processor")
@dataclass
class RelativeStateProcessorStep(ProcessorStep):
"""Converts observation.state to relative (offset from current timestep).
UMI-style relative proprioception: each state timestep is expressed as
an offset from the current EE pose, providing velocity information.
During training (multi-timestep input from ``state_delta_indices``):
``state[..., t, :] -= state[..., -1, :]`` (subtract current from all timesteps).
During inference (single timestep): buffers the previous state and stacks
``[previous, current]`` before applying the relative conversion, producing
the same ``[n_obs, D]`` shape the model expects.
Attributes:
enabled: Whether to apply the relative conversion.
exclude_joints: Joint/dim names to keep absolute.
state_names: State dimension names from dataset metadata.
"""
enabled: bool = False
exclude_joints: list[str] = field(default_factory=list)
state_names: list[str] | None = None
_previous_state: torch.Tensor | None = field(default=None, init=False, repr=False)
def _build_mask(self, state_dim: int) -> list[bool]:
if not self.exclude_joints or self.state_names is None:
return [True] * state_dim
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
if not exclude_tokens:
return [True] * state_dim
mask = []
for name in self.state_names[:state_dim]:
state_name = str(name).lower()
is_excluded = any(token == state_name or token in state_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < state_dim:
mask.extend([True] * (state_dim - len(mask)))
return mask
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
observation = transition.get(TransitionKey.OBSERVATION, {})
state = observation.get(OBS_STATE) if observation else None
if state is None:
return transition
new_transition = transition.copy()
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
mask = self._build_mask(state.shape[-1])
if state.ndim >= 3:
# [B, n_obs, D] — multi-timestep (training with state_delta_indices)
relative = to_relative_state(state, mask)
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, n_obs*D]
elif state.ndim == 2:
# [B, D] — single timestep (inference): buffer previous and stack
current = state
if self._previous_state is None:
self._previous_state = current.clone()
prev = self._previous_state
if prev.device != current.device or prev.dtype != current.dtype:
prev = prev.to(device=current.device, dtype=current.dtype)
stacked = torch.stack([prev, current], dim=-2) # [B, 2, D]
relative = to_relative_state(stacked, mask)
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, 2*D]
self._previous_state = current.clone()
new_transition[TransitionKey.OBSERVATION] = new_obs
return new_transition
def reset(self) -> None:
"""Reset the state buffer. Call at episode boundaries during inference."""
self._previous_state = None
def get_config(self) -> dict[str, Any]:
return {
"enabled": self.enabled,
"exclude_joints": self.exclude_joints,
"state_names": self.state_names,
}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
@@ -254,6 +254,10 @@ class RecomputeStatsConfig(OperationConfig):
relative_exclude_joints: list[str] | None = None
chunk_size: int = 50
num_workers: int = 0
relative_state: bool = False
relative_exclude_state_joints: list[str] | None = None
state_obs_steps: int = 2
derive_state_from_action: bool = False
@OperationConfig.register_subclass("info")
@@ -563,6 +567,14 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, "
f"exclude_joints={cfg.operation.relative_exclude_joints})"
)
if cfg.operation.relative_state:
logging.info(
f"Relative state stats enabled (state_obs_steps={cfg.operation.state_obs_steps}, "
f"exclude_state_joints={cfg.operation.relative_exclude_state_joints})"
)
if cfg.operation.derive_state_from_action:
logging.info("Derive state from action enabled (implies relative_state=True, state_obs_steps=2)")
recompute_stats(
dataset,
@@ -571,6 +583,10 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
relative_exclude_joints=cfg.operation.relative_exclude_joints,
chunk_size=cfg.operation.chunk_size,
num_workers=cfg.operation.num_workers,
relative_state=cfg.operation.relative_state,
relative_exclude_state_joints=cfg.operation.relative_exclude_state_joints,
state_obs_steps=cfg.operation.state_obs_steps,
derive_state_from_action=cfg.operation.derive_state_from_action,
)
logging.info(f"Stats written to {dataset.root}")