mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
refactor to use relative state
This commit is contained in:
@@ -202,11 +202,22 @@ Here is how the different processors compose. Each arrow is a processor step, an
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────┐
|
||||
Representation │ Absolute ←────→ Relative │
|
||||
State Derivation │ Action column ────→ State + Action │
|
||||
│ DeriveStateFromActionStep (pre only) │
|
||||
│ (UMI-style: state from action chunk) │
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────┐
|
||||
Action Repr. │ Absolute ←────→ Relative │
|
||||
│ RelativeActionsProcessorStep (pre) │
|
||||
│ AbsoluteActionsProcessorStep (post) │
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────┐
|
||||
State Repr. │ Absolute ────→ Relative │
|
||||
│ RelativeStateProcessorStep (pre only) │
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────┐
|
||||
Normalization │ Raw ←────→ Normalized │
|
||||
│ NormalizerProcessorStep (pre) │
|
||||
@@ -216,6 +227,10 @@ Here is how the different processors compose. Each arrow is a processor step, an
|
||||
|
||||
A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`.
|
||||
|
||||
With UMI-style relative proprioception (`use_relative_state=True`), the preprocessor also converts observation.state to offsets from the current timestep via `RelativeStateProcessorStep` before normalization. This is a pre-processing-only step (state is an input, not an output).
|
||||
|
||||
With `derive_state_from_action=True`, the preprocessor first runs `DeriveStateFromActionStep` to extract a 2-step state from the extended action chunk. This enables full UMI-style training without a separate `observation.state` column. See the [UMI pi0 guide](umi_pi0_relative_ee) for details.
|
||||
|
||||
## References
|
||||
|
||||
- [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions.
|
||||
|
||||
@@ -4,16 +4,13 @@ This guide explains how to prepare a UMI-collected dataset for training a pi0 po
|
||||
|
||||
**What we will do:**
|
||||
|
||||
1. How to add `observation.state` to an existing UMI LeRobot dataset.
|
||||
2. How to train pi0 with `use_relative_actions=True`.
|
||||
3. How to evaluate the trained policy on a real robot.
|
||||
1. Recompute dataset statistics for relative actions and state.
|
||||
2. Train pi0 with `derive_state_from_action=true` (full UMI pipeline).
|
||||
3. Evaluate the trained policy on a real robot.
|
||||
|
||||
## Background
|
||||
|
||||
[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. UMI datasets stored in LeRobot format already contain `action` (absolute EE pose) and wrist-camera images. To train pi0 with relative actions, we need two additions:
|
||||
|
||||
1. **`observation.state`** — the current EE pose the policy conditions on.
|
||||
2. **Relative action statistics** — so the normalizer sees `(action − state)` distributions.
|
||||
[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. UMI datasets stored in LeRobot format already contain `action` (absolute EE pose) and wrist-camera images. To train pi0 with relative actions, we need **relative action statistics** — so the normalizer sees `(action − state)` distributions.
|
||||
|
||||
### Why relative actions?
|
||||
|
||||
@@ -25,72 +22,39 @@ relative_action[i] = absolute_action[t + i] − state[t]
|
||||
|
||||
This is the representation advocated by UMI (Chi et al., 2024). Compared to absolute actions it removes the need for a consistent global coordinate frame, and compared to delta actions (each step relative to the previous) it avoids error accumulation across the chunk. See the [Action Representations](action_representations) guide for a full comparison.
|
||||
|
||||
## State-Action Offset
|
||||
### Full UMI mode: `derive_state_from_action`
|
||||
|
||||
UMI SLAM produces a single trajectory of EE poses stored as `action`. We derive `observation.state` from the same trajectory with a configurable offset:
|
||||
When `derive_state_from_action=true`, pi0 automatically derives `observation.state` from the action column on the fly — no separate state column or dataset conversion step needed. Under the hood:
|
||||
|
||||
```
|
||||
state[t] = action[t - offset]
|
||||
```
|
||||
- `action_delta_indices` extends to `[-1, 0, 1, ..., 49]` (one extra leading timestep).
|
||||
- `DeriveStateFromActionStep` extracts `[action[t-1], action[t]]` as a 2-step state and strips the extra timestep from the action chunk.
|
||||
- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
|
||||
- `RelativeStateProcessorStep` converts the 2-step state to relative proprioception (velocity + zeros) and flattens.
|
||||
|
||||
| Offset | `state[t]` | Meaning |
|
||||
| ------ | ------------- | ---------------------------------------------------------------- |
|
||||
| 0 | `action[t]` | State and action are the same pose at time t |
|
||||
| 1 | `action[t-1]` | State is the previous action — where the gripper already arrived |
|
||||
This single flag implies `use_relative_state=true` and `state_obs_steps=2`.
|
||||
|
||||
An offset of 1 is the typical UMI convention: at decision time the "current state" is where the gripper _already is_ (the result of the previous command), and the action is where it should go next. At episode boundaries where `t < offset`, we clamp to `action[0]`.
|
||||
During **inference**, state comes from the robot (via FK), so `DeriveStateFromActionStep` is a no-op. `RelativeStateProcessorStep` buffers the previous state and applies the same conversion automatically.
|
||||
|
||||
## Step 1: Add `observation.state`
|
||||
## Step 1: Recompute Stats
|
||||
|
||||
pi0 with `use_relative_actions=True` needs `observation.state` in the dataset to compute `action - state` on the fly. The script in `examples/umi_pi0_relative_ee/convert_umi_dataset.py` adds it. Edit the constants at the top:
|
||||
|
||||
```python
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Option A: Copy an existing feature as observation.state
|
||||
STATE_SOURCE_FEATURE = "observation.joints" # or "observation.pose", etc.
|
||||
|
||||
# Option B: Derive from action with offset (set STATE_SOURCE_FEATURE = None)
|
||||
STATE_SOURCE_FEATURE = None
|
||||
STATE_ACTION_OFFSET = 1
|
||||
```
|
||||
|
||||
**Choosing the state source:**
|
||||
|
||||
- If your dataset already has a feature in the same space as `action` (e.g. `observation.joints` for joint-space actions, or `observation.pose` for EE-space actions), set `STATE_SOURCE_FEATURE` to copy it.
|
||||
- If your dataset only has a single trajectory (like raw UMI EE data where action = the EE poses), set `STATE_SOURCE_FEATURE = None` and use `STATE_ACTION_OFFSET` to derive state from the action column with a time offset.
|
||||
|
||||
`observation.state` **must have the same shape as `action`** — the relative conversion computes `action - state` element-wise.
|
||||
|
||||
Then run:
|
||||
|
||||
```bash
|
||||
python examples/umi_pi0_relative_ee/convert_umi_dataset.py
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
If your dataset already has `observation.state`, the script exits early — nothing to do.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Step 2: Recompute Relative Action Stats
|
||||
|
||||
Use the built-in CLI to recompute dataset statistics in relative space:
|
||||
Use the built-in CLI to recompute dataset statistics for relative actions and derive-state-from-action:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
--repo_id <your_dataset> \
|
||||
--operation.type recompute_stats \
|
||||
--operation.relative_action true \
|
||||
--operation.derive_state_from_action true \
|
||||
--operation.chunk_size 50 \
|
||||
--operation.relative_exclude_joints "['gripper']" \
|
||||
--push_to_hub true
|
||||
```
|
||||
|
||||
The `derive_state_from_action` flag tells `recompute_stats` to read from the action column (instead of `observation.state`) when computing relative state stats. It automatically enables `relative_state=true` and `state_obs_steps=2`.
|
||||
|
||||
The `relative_exclude_joints` parameter specifies joints that stay absolute. Gripper commands are typically binary or continuous open/close and don't benefit from relative encoding. Leave it as `"[]"` to convert all dimensions to relative.
|
||||
|
||||
## Step 3: Train
|
||||
## Step 2: Train
|
||||
|
||||
No custom training script is needed — standard `lerobot-train` handles everything:
|
||||
|
||||
@@ -99,19 +63,26 @@ lerobot-train \
|
||||
--dataset.repo_id=<hf_username>/<dataset_repo_id> \
|
||||
--policy.type=pi0 \
|
||||
--policy.pretrained_path=lerobot/pi0_base \
|
||||
--policy.derive_state_from_action=true \
|
||||
--policy.use_relative_actions=true \
|
||||
--policy.relative_exclude_joints='["gripper"]'
|
||||
```
|
||||
|
||||
`derive_state_from_action=true` auto-enables `use_relative_state=true` and `state_obs_steps=2`.
|
||||
|
||||
Under the hood, the training pipeline:
|
||||
|
||||
- Loads relative action stats from the dataset's `meta/stats.json`.
|
||||
- Configures `RelativeActionsProcessorStep` in the preprocessor (absolute → relative before normalization).
|
||||
- The model trains on normalized relative action values.
|
||||
- Loads relative action stats and relative state stats from the dataset's `meta/stats.json`.
|
||||
- Extends `action_delta_indices` to `[-1, 0, 1, ..., 49]` to load one extra leading timestep.
|
||||
- `DeriveStateFromActionStep` extracts the 2-step state from the action chunk and strips the extra timestep.
|
||||
- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
|
||||
- `RelativeStateProcessorStep` converts the 2-step state to relative offsets from the current timestep, then flattens.
|
||||
- `NormalizerProcessorStep` normalizes everything.
|
||||
- The model trains on normalized relative values.
|
||||
|
||||
See the [pi0 documentation](pi0) for all available training options.
|
||||
|
||||
## Step 4: Evaluate
|
||||
## Step 3: Evaluate
|
||||
|
||||
The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real robot (SO-100 with EE space):
|
||||
|
||||
@@ -121,10 +92,20 @@ python examples/umi_pi0_relative_ee/evaluate.py
|
||||
|
||||
Edit `HF_MODEL_ID`, `HF_DATASET_ID`, and robot configuration at the top of the file.
|
||||
|
||||
### Latency compensation
|
||||
|
||||
For real robot deployment, you may want to skip the first few steps of each predicted action chunk to compensate for system latency. Set `LATENCY_SKIP_STEPS` in the evaluate script:
|
||||
|
||||
```python
|
||||
LATENCY_SKIP_STEPS = 0 # ceil(total_latency_ms / (1000 / FPS))
|
||||
```
|
||||
|
||||
For example, at 10Hz with ~200ms total latency, set `LATENCY_SKIP_STEPS = 2`.
|
||||
|
||||
The inference flow uses pi0's built-in processor pipeline — no custom wrappers needed:
|
||||
|
||||
1. **Robot → FK** — Joint positions are converted to EE pose via `ForwardKinematicsJointsToEE`, producing `observation.state`.
|
||||
2. **Preprocessor** — `RelativeActionsProcessorStep` caches the raw `observation.state`, then `NormalizerProcessorStep` normalizes everything.
|
||||
2. **Preprocessor** — `DeriveStateFromActionStep` is a no-op (state comes from robot). `RelativeStateProcessorStep` buffers previous state, stacks, and converts to relative. `RelativeActionsProcessorStep` caches state. `NormalizerProcessorStep` normalizes.
|
||||
3. **pi0 inference** — The model predicts a normalized relative action chunk.
|
||||
4. **Postprocessor** — `UnnormalizerProcessorStep` unnormalizes, then `AbsoluteActionsProcessorStep` adds the cached state back to get absolute EE targets.
|
||||
5. **IK → Robot** — `InverseKinematicsEEToJoints` converts absolute EE targets to joint commands.
|
||||
@@ -132,14 +113,22 @@ The inference flow uses pi0's built-in processor pipeline — no custom wrappers
|
||||
## How the Pieces Fit Together
|
||||
|
||||
```
|
||||
Training:
|
||||
dataset (absolute EE) → RelativeActionsProcessorStep → NormalizerProcessorStep → pi0 model
|
||||
Training (full UMI mode: derive_state_from_action=true):
|
||||
DataLoader (action: B,51,D)
|
||||
→ DeriveStateFromActionStep (state = action[:,:2,:], action = action[:,1:,:])
|
||||
→ RelativeActionsProcessorStep (action -= state[:,-1,:])
|
||||
→ RelativeStateProcessorStep (state offsets from current, flatten → B,2*D)
|
||||
→ NormalizerProcessorStep → pi0 model
|
||||
|
||||
Inference:
|
||||
robot joints → FK → observation.state (absolute EE)
|
||||
↓
|
||||
DeriveStateFromActionStep (no-op)
|
||||
↓
|
||||
RelativeActionsProcessorStep (caches state)
|
||||
↓
|
||||
RelativeStateProcessorStep (buffers prev, stacks, subtracts, flattens)
|
||||
↓
|
||||
NormalizerProcessorStep → pi0 model → relative action chunk
|
||||
↓
|
||||
UnnormalizerProcessorStep
|
||||
@@ -149,6 +138,31 @@ Inference:
|
||||
IK → joint targets → robot
|
||||
```
|
||||
|
||||
## Manual mode (without derive_state_from_action)
|
||||
|
||||
If your dataset already has `observation.state` (or you want to add it separately), you can skip `derive_state_from_action` and use relative actions + relative state independently:
|
||||
|
||||
```bash
|
||||
# Recompute stats
|
||||
lerobot-edit-dataset \
|
||||
--repo_id <your_dataset> \
|
||||
--operation.type recompute_stats \
|
||||
--operation.relative_action true \
|
||||
--operation.relative_state true \
|
||||
--operation.state_obs_steps 2 \
|
||||
--operation.chunk_size 50 \
|
||||
--operation.relative_exclude_joints "['gripper']"
|
||||
|
||||
# Train
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<your_dataset> \
|
||||
--policy.type=pi0 \
|
||||
--policy.use_relative_actions=true \
|
||||
--policy.use_relative_state=true \
|
||||
--policy.state_obs_steps=2 \
|
||||
--policy.relative_exclude_joints='["gripper"]'
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [UMI: Universal Manipulation Interface](https://umi-gripper.github.io) — Chi et al., 2024. Defines relative trajectory actions.
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Add ``observation.state`` to an existing LeRobot dataset.
|
||||
|
||||
pi0 uses ``observation.state`` as its proprioceptive input AND for
|
||||
relative action conversion (action − state). This script creates
|
||||
``observation.state`` by concatenating one or more existing features.
|
||||
|
||||
Ordering matters: the features whose dimensions correspond to ``action``
|
||||
must come FIRST, because ``RelativeActionsProcessorStep`` subtracts
|
||||
``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE
|
||||
pose) are appended after and are seen by the model but not used for
|
||||
relative conversion.
|
||||
|
||||
Example: action = [proximal, distal], and we want the model to also see
|
||||
the EE pose:
|
||||
|
||||
STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"]
|
||||
→ observation.state = [proximal, distal, x, y, z, ax, ay, az]
|
||||
|
||||
The relative conversion uses state[:2] = [proximal, distal] to subtract
|
||||
from action[:2], and the model sees all 8 dimensions.
|
||||
|
||||
After running this script, recompute relative action stats:
|
||||
|
||||
lerobot-edit-dataset \\
|
||||
--repo_id <your_dataset> \\
|
||||
--operation.type recompute_stats \\
|
||||
--operation.relative_action true \\
|
||||
--operation.chunk_size 50 \\
|
||||
--operation.relative_exclude_joints "[]" \\
|
||||
--push_to_hub true
|
||||
|
||||
Usage:
|
||||
python convert_umi_dataset.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.dataset_tools import add_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HF_DATASET_ID = ""
|
||||
|
||||
# Output repo ID. Set to None for default "<input>_modified".
|
||||
OUTPUT_REPO_ID: str | None = None
|
||||
|
||||
# Features to concatenate into observation.state. Order matters:
|
||||
# action-matching features FIRST, then extra proprioception.
|
||||
# Set to a single string to copy one feature directly.
|
||||
STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"]
|
||||
|
||||
# Only used when STATE_SOURCE_FEATURES is None:
|
||||
# derive state from action with a per-episode offset.
|
||||
STATE_ACTION_OFFSET = 1
|
||||
|
||||
# Push the augmented dataset to the Hugging Face Hub.
|
||||
PUSH_TO_HUB = True
|
||||
|
||||
|
||||
def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]:
|
||||
hf = dataset.hf_dataset
|
||||
ep = np.array(hf["episode_index"])
|
||||
fr = np.array(hf["frame_index"])
|
||||
return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))}
|
||||
|
||||
|
||||
def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable:
|
||||
"""Concatenate multiple features into observation.state."""
|
||||
hf = dataset.hf_dataset
|
||||
key_to_global = _build_global_index(dataset)
|
||||
|
||||
columns = [hf[feat] for feat in source_features]
|
||||
|
||||
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
|
||||
g = key_to_global[(ep_idx, frame_idx)]
|
||||
parts = []
|
||||
for col in columns:
|
||||
val = col[g]
|
||||
if hasattr(val, "tolist"):
|
||||
flat = val.tolist()
|
||||
if isinstance(flat, list):
|
||||
parts.extend(flat)
|
||||
else:
|
||||
parts.append(flat)
|
||||
elif isinstance(val, list):
|
||||
parts.extend(val)
|
||||
else:
|
||||
parts.append(float(val))
|
||||
return parts
|
||||
|
||||
return _get_state
|
||||
|
||||
|
||||
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
|
||||
"""Derive state from action with a per-episode offset."""
|
||||
hf = dataset.hf_dataset
|
||||
episode_indices = np.array(hf["episode_index"])
|
||||
frame_indices = np.array(hf["frame_index"])
|
||||
|
||||
ep_sorted: dict[int, list[tuple[int, int]]] = {}
|
||||
for ep_idx in np.unique(episode_indices):
|
||||
ep_mask = episode_indices == ep_idx
|
||||
ep_globals = np.where(ep_mask)[0]
|
||||
ep_frames = frame_indices[ep_globals]
|
||||
order = np.argsort(ep_frames)
|
||||
ep_sorted[int(ep_idx)] = [(int(ep_frames[o]), int(ep_globals[o])) for o in order]
|
||||
|
||||
ep_frame_to_local: dict[int, dict[int, int]] = {}
|
||||
for ep_idx, sorted_list in ep_sorted.items():
|
||||
ep_frame_to_local[ep_idx] = {frame: local for local, (frame, _) in enumerate(sorted_list)}
|
||||
|
||||
actions = hf["action"]
|
||||
|
||||
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
|
||||
local_t = ep_frame_to_local[ep_idx][frame_idx]
|
||||
source_local = max(0, local_t - offset)
|
||||
_, source_global = ep_sorted[ep_idx][source_local]
|
||||
return actions[source_global]
|
||||
|
||||
return _get_state
|
||||
|
||||
|
||||
def main():
|
||||
logger.info(f"Loading dataset {HF_DATASET_ID}")
|
||||
dataset = LeRobotDataset(HF_DATASET_ID)
|
||||
|
||||
if "observation.state" in dataset.features:
|
||||
logger.info("observation.state already exists — nothing to do")
|
||||
return
|
||||
|
||||
action_meta = dataset.features["action"]
|
||||
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
|
||||
|
||||
if STATE_SOURCE_FEATURES is not None:
|
||||
source_list = (
|
||||
[STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES)
|
||||
)
|
||||
for feat in source_list:
|
||||
if feat not in dataset.features:
|
||||
raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}")
|
||||
|
||||
# Compute combined shape and names
|
||||
total_dim = 0
|
||||
all_names = []
|
||||
for feat in source_list:
|
||||
meta = dataset.features[feat]
|
||||
total_dim += meta["shape"][0]
|
||||
names = meta.get("names")
|
||||
if names:
|
||||
all_names.extend(names)
|
||||
|
||||
logger.info(
|
||||
f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})"
|
||||
)
|
||||
state_fn = _build_state_from_features(dataset, source_list)
|
||||
state_feature_info = {
|
||||
"dtype": "float32",
|
||||
"shape": [total_dim],
|
||||
"names": all_names or None,
|
||||
}
|
||||
else:
|
||||
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
|
||||
state_fn = _build_state_from_action_offset(dataset, offset=STATE_ACTION_OFFSET)
|
||||
state_feature_info = {
|
||||
"dtype": "float32",
|
||||
"shape": list(action_meta["shape"]),
|
||||
"names": action_meta.get("names"),
|
||||
}
|
||||
|
||||
augmented = add_features(
|
||||
dataset,
|
||||
features={"observation.state": (state_fn, state_feature_info)},
|
||||
repo_id=OUTPUT_REPO_ID,
|
||||
)
|
||||
logger.info("observation.state added")
|
||||
|
||||
if PUSH_TO_HUB:
|
||||
logger.info(f"Pushing to Hub: {augmented.repo_id}")
|
||||
augmented.push_to_hub()
|
||||
|
||||
logger.info(
|
||||
f"Done. Dataset at: {augmented.root}\n"
|
||||
"Now recompute relative action stats:\n"
|
||||
" lerobot-edit-dataset \\\n"
|
||||
f" --repo_id {augmented.repo_id} \\\n"
|
||||
" --operation.type recompute_stats \\\n"
|
||||
" --operation.relative_action true \\\n"
|
||||
" --operation.chunk_size 50 \\\n"
|
||||
' --operation.relative_exclude_joints "[]" \\\n'
|
||||
" --push_to_hub true"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -17,17 +17,20 @@
|
||||
"""
|
||||
Inference script for a pi0 model trained with **relative EE actions**.
|
||||
|
||||
This uses the built-in ``RelativeActionsProcessorStep`` and
|
||||
``AbsoluteActionsProcessorStep`` that are already wired into pi0's
|
||||
processor pipeline when ``use_relative_actions=True``.
|
||||
This uses the built-in ``DeriveStateFromActionStep`` (no-op at inference),
|
||||
``RelativeActionsProcessorStep``, ``AbsoluteActionsProcessorStep``, and
|
||||
``RelativeStateProcessorStep`` that are already wired into pi0's processor
|
||||
pipeline.
|
||||
|
||||
The inference loop:
|
||||
1. Reads joint positions from the robot.
|
||||
2. Converts to EE pose via forward kinematics (FK).
|
||||
This produces ``observation.state`` with the current EE pose.
|
||||
3. The pi0 preprocessor:
|
||||
a) ``RelativeActionsProcessorStep`` caches the raw state.
|
||||
b) ``NormalizerProcessorStep`` normalizes state and actions.
|
||||
a) ``DeriveStateFromActionStep`` — no-op (state comes from robot).
|
||||
b) ``RelativeActionsProcessorStep`` caches the raw state.
|
||||
c) ``RelativeStateProcessorStep`` buffers prev state, stacks, subtracts.
|
||||
d) ``NormalizerProcessorStep`` normalizes state and actions.
|
||||
4. pi0 predicts relative action chunk.
|
||||
5. The pi0 postprocessor:
|
||||
a) ``UnnormalizerProcessorStep`` unnormalizes.
|
||||
@@ -51,6 +54,7 @@ from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.processor import (
|
||||
RelativeStateProcessorStep,
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
@@ -79,6 +83,11 @@ TASK_DESCRIPTION = "manipulation task"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Latency compensation: skip this many steps from the start of each predicted
|
||||
# action chunk. Formula: ceil(total_latency_ms / (1000 / FPS)).
|
||||
# E.g. at 10Hz with ~200ms total system latency: ceil(200 / 100) = 2.
|
||||
LATENCY_SKIP_STEPS = 0
|
||||
|
||||
# EE feature keys produced by ForwardKinematicsJointsToEE
|
||||
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
|
||||
@@ -94,6 +103,7 @@ def main():
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
policy = PI0Policy.from_pretrained(HF_MODEL_ID)
|
||||
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
|
||||
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
@@ -151,9 +161,8 @@ def main():
|
||||
|
||||
# Build pre/post processors from the trained model.
|
||||
# The pi0 processor pipeline already includes:
|
||||
# pre: ... → RelativeActionsProcessorStep → NormalizerProcessorStep
|
||||
# pre: ... → RelativeStateProcessorStep → RelativeActionsProcessorStep → NormalizerProcessorStep
|
||||
# post: UnnormalizerProcessorStep → AbsoluteActionsProcessorStep → ...
|
||||
# These handle the relative ↔ absolute conversion automatically.
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
@@ -161,6 +170,9 @@ def main():
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Find the relative state step (if present) so we can reset its buffer between episodes.
|
||||
_relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
|
||||
|
||||
robot.connect()
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
@@ -174,6 +186,10 @@ def main():
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Reset relative state buffer so velocity is zero at episode start
|
||||
for step in _relative_state_steps:
|
||||
step.reset()
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
|
||||
@@ -115,6 +115,17 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def state_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
"""Delta indices specifically for observation.state.
|
||||
|
||||
When not None, overrides ``observation_delta_indices`` for the
|
||||
``observation.state`` key only. Useful for loading state history
|
||||
(e.g. ``[-1, 0]`` for UMI-style relative proprioception) without
|
||||
also loading multiple image timesteps.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -767,3 +767,94 @@ def compute_relative_action_stats(
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def compute_relative_state_stats(
|
||||
hf_dataset,
|
||||
features: dict,
|
||||
state_obs_steps: int = 2,
|
||||
exclude_joints: list[str] | None = None,
|
||||
source_key: str = OBS_STATE,
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Compute normalization statistics for observation.state after relative conversion.
|
||||
|
||||
For UMI-style relative proprioception with ``state_obs_steps`` timesteps,
|
||||
each state observation becomes a stack of offsets from the current timestep:
|
||||
``state[t-k] - state[t]`` for k in ``range(state_obs_steps-1, -1, -1)``.
|
||||
|
||||
The stats are computed over the flattened ``[state_obs_steps * state_dim]``
|
||||
vector that the model actually sees after ``prepare_state`` flattening.
|
||||
|
||||
Args:
|
||||
hf_dataset: The HuggingFace dataset with the source column and
|
||||
"episode_index" columns.
|
||||
features: Dataset feature metadata.
|
||||
state_obs_steps: Number of observation timesteps (must be >= 2).
|
||||
exclude_joints: State dimension names to keep absolute.
|
||||
source_key: Column to read data from. Defaults to "observation.state".
|
||||
When ``derive_state_from_action=True``, pass ``ACTION`` to read
|
||||
from the action column instead.
|
||||
|
||||
Returns:
|
||||
Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99".
|
||||
"""
|
||||
from lerobot.processor.relative_action_processor import RelativeStateProcessorStep
|
||||
|
||||
if exclude_joints is None:
|
||||
exclude_joints = []
|
||||
|
||||
state_dim = features[source_key]["shape"][0]
|
||||
state_names = features.get(source_key, {}).get("names")
|
||||
mask_step = RelativeStateProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=exclude_joints,
|
||||
state_names=state_names,
|
||||
)
|
||||
relative_mask = np.array(mask_step._build_mask(state_dim), dtype=np.float32)
|
||||
|
||||
logging.info(f"Loading data from '{source_key}' for relative state stats...")
|
||||
all_states = np.array(hf_dataset[source_key], dtype=np.float32)
|
||||
episode_indices = np.array(hf_dataset["episode_index"])
|
||||
|
||||
# Build all valid windows of length state_obs_steps within each episode
|
||||
n = len(all_states)
|
||||
if n < state_obs_steps:
|
||||
raise ValueError(f"Dataset has {n} frames but state_obs_steps={state_obs_steps}")
|
||||
|
||||
max_start = n - state_obs_steps
|
||||
starts = np.arange(max_start + 1)
|
||||
valid = episode_indices[starts] == episode_indices[starts + state_obs_steps - 1]
|
||||
valid_starts = starts[valid]
|
||||
|
||||
if len(valid_starts) == 0:
|
||||
raise RuntimeError("No valid state windows found within single episodes")
|
||||
|
||||
offsets = np.arange(state_obs_steps)
|
||||
mask_dim = len(relative_mask)
|
||||
|
||||
running_stats = RunningQuantileStats()
|
||||
|
||||
batch_size = 50_000
|
||||
for i in range(0, len(valid_starts), batch_size):
|
||||
batch_starts = valid_starts[i : i + batch_size]
|
||||
frame_idx = batch_starts[:, None] + offsets[None, :] # [N, state_obs_steps]
|
||||
windows = all_states[frame_idx].copy() # [N, state_obs_steps, state_dim]
|
||||
|
||||
# Subtract current (last) timestep from all timesteps for masked dims
|
||||
current = windows[:, -1:, :] # [N, 1, state_dim]
|
||||
windows[:, :, :mask_dim] -= current[:, :, :mask_dim] * relative_mask[None, None, :]
|
||||
|
||||
# Flatten to [N, state_obs_steps * state_dim] (same as prepare_state)
|
||||
flattened = windows.reshape(len(batch_starts), -1)
|
||||
running_stats.update(flattened)
|
||||
|
||||
stats = running_stats.get_statistics()
|
||||
|
||||
excluded_dims = int(mask_dim - relative_mask.sum())
|
||||
logging.info(
|
||||
f"Relative state stats ({len(valid_starts)} windows, obs_steps={state_obs_steps}): "
|
||||
f"relative_dims={int(relative_mask.sum())}/{mask_dim} (excluded={excluded_dims}), "
|
||||
f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}"
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
@@ -41,6 +41,7 @@ from lerobot.datasets.compute_stats import (
|
||||
aggregate_stats,
|
||||
compute_episode_stats,
|
||||
compute_relative_action_stats,
|
||||
compute_relative_state_stats,
|
||||
)
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.io_utils import (
|
||||
@@ -1544,6 +1545,10 @@ def recompute_stats(
|
||||
relative_exclude_joints: list[str] | None = None,
|
||||
chunk_size: int = 50,
|
||||
num_workers: int = 0,
|
||||
relative_state: bool = False,
|
||||
relative_exclude_state_joints: list[str] | None = None,
|
||||
state_obs_steps: int = 2,
|
||||
derive_state_from_action: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
"""Recompute stats.json from scratch by iterating all episodes.
|
||||
|
||||
@@ -1561,10 +1566,22 @@ def recompute_stats(
|
||||
``policy.chunk_size``. Only used when ``relative_action=True``.
|
||||
num_workers: Number of parallel threads for relative action stats computation.
|
||||
Values ≤1 mean single-threaded. Only used when ``relative_action=True``.
|
||||
relative_state: If True, compute observation.state stats in relative space
|
||||
(multi-timestep offsets from current). This matches the normalization
|
||||
the model sees during training with ``use_relative_state=True``.
|
||||
relative_exclude_state_joints: State dim names to exclude from relative conversion.
|
||||
state_obs_steps: Number of observation timesteps for relative state stats.
|
||||
Should match ``policy.state_obs_steps``. Only used when ``relative_state=True``.
|
||||
derive_state_from_action: If True, compute relative state stats from the
|
||||
action column instead of observation.state. Implies ``relative_state=True``
|
||||
and ``state_obs_steps=2``.
|
||||
|
||||
Returns:
|
||||
The same dataset with updated stats.
|
||||
"""
|
||||
if derive_state_from_action:
|
||||
relative_state = True
|
||||
state_obs_steps = 2
|
||||
features = dataset.meta.features
|
||||
meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"}
|
||||
numeric_features = {
|
||||
@@ -1596,6 +1613,20 @@ def recompute_stats(
|
||||
)
|
||||
features_to_compute.pop(ACTION, None)
|
||||
|
||||
# When relative_state is enabled, compute state stats over the flattened
|
||||
# multi-timestep relative representation (matching what the model sees).
|
||||
relative_state_stats = None
|
||||
if relative_state and (OBS_STATE in features or derive_state_from_action):
|
||||
source_key = ACTION if derive_state_from_action else OBS_STATE
|
||||
relative_state_stats = compute_relative_state_stats(
|
||||
hf_dataset=dataset.hf_dataset,
|
||||
features=features,
|
||||
state_obs_steps=state_obs_steps,
|
||||
exclude_joints=relative_exclude_state_joints,
|
||||
source_key=source_key,
|
||||
)
|
||||
features_to_compute.pop(OBS_STATE, None)
|
||||
|
||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
@@ -1632,6 +1663,9 @@ def recompute_stats(
|
||||
if relative_action_stats is not None:
|
||||
new_stats[ACTION] = relative_action_stats
|
||||
|
||||
if relative_state_stats is not None:
|
||||
new_stats[OBS_STATE] = relative_state_stats
|
||||
|
||||
# Merge: keep existing stats for features we didn't recompute
|
||||
if dataset.meta.stats:
|
||||
for key, value in dataset.meta.stats.items():
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, OBS_STATE, REWARD
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
@@ -52,12 +52,15 @@ def resolve_delta_timestamps(
|
||||
returns `None` if the resulting dict is empty.
|
||||
"""
|
||||
delta_timestamps = {}
|
||||
state_delta = getattr(cfg, "state_delta_indices", None)
|
||||
for key in ds_meta.features:
|
||||
if key == REWARD and cfg.reward_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == ACTION and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
if key == OBS_STATE and state_delta is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in state_delta]
|
||||
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
|
||||
@@ -57,6 +57,28 @@ class PI0Config(PreTrainedConfig):
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Relative state (UMI-style relative proprioception): converts multi-timestep
|
||||
# observation.state to offsets from the current timestep, providing velocity info.
|
||||
# Requires state_obs_steps >= 2. The flattened multi-timestep state is padded to
|
||||
# max_state_dim, so ensure state_obs_steps * state_dim <= max_state_dim.
|
||||
use_relative_state: bool = False
|
||||
state_obs_steps: int = 1
|
||||
relative_exclude_state_joints: list[str] = field(default_factory=list)
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
state_feature_names: list[str] | None = None
|
||||
|
||||
# Derive observation.state from the action column (UMI-style).
|
||||
# When True, action_delta_indices loads one extra leading timestep [-1, 0, ..., chunk_size-1],
|
||||
# DeriveStateFromActionStep extracts [action[t-1], action[t]] as a 2-step state,
|
||||
# and strips the extra timestep from the action chunk.
|
||||
# Implies use_relative_state=True and state_obs_steps=2.
|
||||
derive_state_from_action: bool = False
|
||||
|
||||
# Latency compensation: skip this many steps from the start of each predicted
|
||||
# action chunk during inference. E.g. at 10Hz with ~200ms total latency,
|
||||
# latency_skip_steps=2 compensates for the delay.
|
||||
latency_skip_steps: int = 0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
@@ -106,6 +128,10 @@ class PI0Config(PreTrainedConfig):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.derive_state_from_action:
|
||||
self.use_relative_state = True
|
||||
self.state_obs_steps = 2
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
@@ -121,6 +147,13 @@ class PI0Config(PreTrainedConfig):
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
if self.use_relative_state and self.state_obs_steps < 2:
|
||||
raise ValueError(
|
||||
"use_relative_state requires state_obs_steps >= 2 "
|
||||
f"(got {self.state_obs_steps}). Set state_obs_steps=2 for "
|
||||
"UMI-style relative proprioception."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
for i in range(self.empty_cameras):
|
||||
@@ -166,8 +199,16 @@ class PI0Config(PreTrainedConfig):
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def state_delta_indices(self) -> list[int] | None:
|
||||
if self.state_obs_steps >= 2:
|
||||
return list(range(-(self.state_obs_steps - 1), 1))
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
if self.derive_state_from_action:
|
||||
return [-1] + list(range(self.chunk_size))
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
|
||||
@@ -1230,8 +1230,11 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return images, img_masks
|
||||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
"""Flatten multi-timestep state and pad to max_state_dim."""
|
||||
state = batch[OBS_STATE]
|
||||
if state.ndim == 3:
|
||||
state = state.flatten(start_dim=1)
|
||||
state = pad_vector(state, self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
@@ -1250,7 +1253,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
skip = self.config.latency_skip_steps
|
||||
actions = self.predict_action_chunk(batch)[:, skip : skip + self.config.n_action_steps]
|
||||
# Transpose to get shape (n_action_steps, batch_size, action_dim)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeriveStateFromActionStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -31,6 +32,7 @@ from lerobot.processor import (
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RelativeActionsProcessorStep,
|
||||
RelativeStateProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
@@ -128,13 +130,25 @@ def make_pi0_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
derive_state_step = DeriveStateFromActionStep(
|
||||
enabled=getattr(config, "derive_state_from_action", False),
|
||||
)
|
||||
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=config.use_relative_actions,
|
||||
exclude_joints=getattr(config, "relative_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
# OpenPI order: raw → relative → normalize → model → unnormalize → absolute
|
||||
relative_state_step = RelativeStateProcessorStep(
|
||||
enabled=getattr(config, "use_relative_state", False),
|
||||
exclude_joints=getattr(config, "relative_exclude_state_joints", []),
|
||||
state_names=getattr(config, "state_feature_names", None),
|
||||
)
|
||||
|
||||
# Order: DeriveStateFromAction extracts state from the extended action chunk,
|
||||
# then relative_action uses current state[t] for subtraction,
|
||||
# then relative_state converts the multi-timestep state to offsets.
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -146,7 +160,9 @@ def make_pi0_pre_post_processors(
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
derive_state_step,
|
||||
relative_step,
|
||||
relative_state_step,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
|
||||
@@ -77,9 +77,12 @@ from .policy_robot_bridge import (
|
||||
)
|
||||
from .relative_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeriveStateFromActionStep,
|
||||
RelativeActionsProcessorStep,
|
||||
RelativeStateProcessorStep,
|
||||
to_absolute_actions,
|
||||
to_relative_actions,
|
||||
to_relative_state,
|
||||
)
|
||||
from .rename_processor import RenameObservationsProcessorStep
|
||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||
@@ -107,7 +110,9 @@ __all__ = [
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"DeriveStateFromActionStep",
|
||||
"RelativeActionsProcessorStep",
|
||||
"RelativeStateProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
@@ -139,6 +144,7 @@ __all__ = [
|
||||
"TruncatedProcessorStep",
|
||||
"to_absolute_actions",
|
||||
"to_relative_actions",
|
||||
"to_relative_state",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -30,10 +30,13 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
__all__ = [
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"DeriveStateFromActionStep",
|
||||
"RelativeActionsProcessorStep",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"RelativeStateProcessorStep",
|
||||
"to_relative_actions",
|
||||
"to_absolute_actions",
|
||||
"to_relative_state",
|
||||
]
|
||||
|
||||
|
||||
@@ -81,6 +84,41 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("derive_state_from_action_processor")
|
||||
@dataclass
|
||||
class DeriveStateFromActionStep(ProcessorStep):
|
||||
"""Derives 2-step observation.state from the action chunk (UMI-style).
|
||||
|
||||
Expects action with one extra leading timestep: [B, chunk_size+1, D]
|
||||
from action_delta_indices = [-1, 0, 1, ..., chunk_size-1].
|
||||
Extracts [action[t-1], action[t]] as state and strips the extra timestep.
|
||||
No-op during inference (state comes from robot).
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
if not self.enabled:
|
||||
return transition
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None or action.ndim < 3:
|
||||
return transition
|
||||
new_transition = transition.copy()
|
||||
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
|
||||
new_obs[OBS_STATE] = action[..., :2, :]
|
||||
new_transition[TransitionKey.ACTION] = action[..., 1:, :]
|
||||
new_transition[TransitionKey.OBSERVATION] = new_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@dataclass
|
||||
class RelativeActionsProcessorStep(ProcessorStep):
|
||||
@@ -124,7 +162,14 @@ class RelativeActionsProcessorStep(ProcessorStep):
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||
state = observation.get(OBS_STATE) if observation else None
|
||||
raw_state = observation.get(OBS_STATE) if observation else None
|
||||
|
||||
# When state_delta_indices loads multi-timestep state [B, n_obs, D],
|
||||
# use only the current (last) timestep for relative action conversion.
|
||||
if raw_state is not None:
|
||||
state = raw_state[..., -1, :] if raw_state.ndim >= 3 else raw_state
|
||||
else:
|
||||
state = None
|
||||
|
||||
# Always cache state for the paired AbsoluteActionsProcessorStep
|
||||
if state is not None:
|
||||
@@ -155,6 +200,120 @@ class RelativeActionsProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def to_relative_state(state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||
"""Convert multi-timestep absolute state to relative (offset from current timestep).
|
||||
|
||||
Each timestep becomes: ``state[..., t, :] - state[..., -1, :]`` for masked dims.
|
||||
The last (current) timestep becomes zeros for masked dims.
|
||||
|
||||
Args:
|
||||
state: (..., n_obs, state_dim) — last timestep is the reference (current).
|
||||
mask: Which dims to convert. Can be shorter than state_dim.
|
||||
"""
|
||||
mask_t = torch.tensor(mask, dtype=state.dtype, device=state.device)
|
||||
dims = mask_t.shape[0]
|
||||
current = state[..., -1:, :] # (..., 1, state_dim)
|
||||
state = state.clone()
|
||||
state[..., :dims] -= current[..., :dims] * mask_t
|
||||
return state
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("relative_state_processor")
|
||||
@dataclass
|
||||
class RelativeStateProcessorStep(ProcessorStep):
|
||||
"""Converts observation.state to relative (offset from current timestep).
|
||||
|
||||
UMI-style relative proprioception: each state timestep is expressed as
|
||||
an offset from the current EE pose, providing velocity information.
|
||||
|
||||
During training (multi-timestep input from ``state_delta_indices``):
|
||||
``state[..., t, :] -= state[..., -1, :]`` — subtract current from all.
|
||||
|
||||
During inference (single timestep): buffers the previous state and stacks
|
||||
``[previous, current]`` before applying the relative conversion, producing
|
||||
the same ``[n_obs, D]`` shape the model expects.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the relative conversion.
|
||||
exclude_joints: Joint/dim names to keep absolute.
|
||||
state_names: State dimension names from dataset metadata.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
exclude_joints: list[str] = field(default_factory=list)
|
||||
state_names: list[str] | None = None
|
||||
_previous_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def _build_mask(self, state_dim: int) -> list[bool]:
|
||||
if not self.exclude_joints or self.state_names is None:
|
||||
return [True] * state_dim
|
||||
|
||||
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||
if not exclude_tokens:
|
||||
return [True] * state_dim
|
||||
|
||||
mask = []
|
||||
for name in self.state_names[:state_dim]:
|
||||
state_name = str(name).lower()
|
||||
is_excluded = any(token == state_name or token in state_name for token in exclude_tokens)
|
||||
mask.append(not is_excluded)
|
||||
|
||||
if len(mask) < state_dim:
|
||||
mask.extend([True] * (state_dim - len(mask)))
|
||||
|
||||
return mask
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||
state = observation.get(OBS_STATE) if observation else None
|
||||
|
||||
if state is None:
|
||||
return transition
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
|
||||
mask = self._build_mask(state.shape[-1])
|
||||
|
||||
if state.ndim >= 3:
|
||||
# [B, n_obs, D] — multi-timestep (training with state_delta_indices)
|
||||
relative = to_relative_state(state, mask)
|
||||
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, n_obs*D]
|
||||
elif state.ndim == 2:
|
||||
# [B, D] — single timestep (inference): buffer previous and stack
|
||||
current = state
|
||||
if self._previous_state is None:
|
||||
self._previous_state = current.clone()
|
||||
prev = self._previous_state
|
||||
if prev.device != current.device or prev.dtype != current.dtype:
|
||||
prev = prev.to(device=current.device, dtype=current.dtype)
|
||||
stacked = torch.stack([prev, current], dim=-2) # [B, 2, D]
|
||||
relative = to_relative_state(stacked, mask)
|
||||
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, 2*D]
|
||||
self._previous_state = current.clone()
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = new_obs
|
||||
return new_transition
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the state buffer. Call at episode boundaries during inference."""
|
||||
self._previous_state = None
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"exclude_joints": self.exclude_joints,
|
||||
"state_names": self.state_names,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("absolute_actions_processor")
|
||||
@dataclass
|
||||
class AbsoluteActionsProcessorStep(ProcessorStep):
|
||||
|
||||
@@ -254,6 +254,10 @@ class RecomputeStatsConfig(OperationConfig):
|
||||
relative_exclude_joints: list[str] | None = None
|
||||
chunk_size: int = 50
|
||||
num_workers: int = 0
|
||||
relative_state: bool = False
|
||||
relative_exclude_state_joints: list[str] | None = None
|
||||
state_obs_steps: int = 2
|
||||
derive_state_from_action: bool = False
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("info")
|
||||
@@ -563,6 +567,14 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
||||
f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, "
|
||||
f"exclude_joints={cfg.operation.relative_exclude_joints})"
|
||||
)
|
||||
if cfg.operation.relative_state:
|
||||
logging.info(
|
||||
f"Relative state stats enabled (state_obs_steps={cfg.operation.state_obs_steps}, "
|
||||
f"exclude_state_joints={cfg.operation.relative_exclude_state_joints})"
|
||||
)
|
||||
|
||||
if cfg.operation.derive_state_from_action:
|
||||
logging.info("Derive state from action enabled (implies relative_state=True, state_obs_steps=2)")
|
||||
|
||||
recompute_stats(
|
||||
dataset,
|
||||
@@ -571,6 +583,10 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
||||
relative_exclude_joints=cfg.operation.relative_exclude_joints,
|
||||
chunk_size=cfg.operation.chunk_size,
|
||||
num_workers=cfg.operation.num_workers,
|
||||
relative_state=cfg.operation.relative_state,
|
||||
relative_exclude_state_joints=cfg.operation.relative_exclude_state_joints,
|
||||
state_obs_steps=cfg.operation.state_obs_steps,
|
||||
derive_state_from_action=cfg.operation.derive_state_from_action,
|
||||
)
|
||||
|
||||
logging.info(f"Stats written to {dataset.root}")
|
||||
|
||||
Reference in New Issue
Block a user