mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
refactor to use relative state
This commit is contained in:
@@ -202,11 +202,22 @@ Here is how the different processors compose. Each arrow is a processor step, an
|
|||||||
└─────────────────────────────────────────┘
|
└─────────────────────────────────────────┘
|
||||||
|
|
||||||
┌─────────────────────────────────────────┐
|
┌─────────────────────────────────────────┐
|
||||||
Representation │ Absolute ←────→ Relative │
|
State Derivation │ Action column ────→ State + Action │
|
||||||
|
│ DeriveStateFromActionStep (pre only) │
|
||||||
|
│ (UMI-style: state from action chunk) │
|
||||||
|
└─────────────────────────────────────────┘
|
||||||
|
|
||||||
|
┌─────────────────────────────────────────┐
|
||||||
|
Action Repr. │ Absolute ←────→ Relative │
|
||||||
│ RelativeActionsProcessorStep (pre) │
|
│ RelativeActionsProcessorStep (pre) │
|
||||||
│ AbsoluteActionsProcessorStep (post) │
|
│ AbsoluteActionsProcessorStep (post) │
|
||||||
└─────────────────────────────────────────┘
|
└─────────────────────────────────────────┘
|
||||||
|
|
||||||
|
┌─────────────────────────────────────────┐
|
||||||
|
State Repr. │ Absolute ────→ Relative │
|
||||||
|
│ RelativeStateProcessorStep (pre only) │
|
||||||
|
└─────────────────────────────────────────┘
|
||||||
|
|
||||||
┌─────────────────────────────────────────┐
|
┌─────────────────────────────────────────┐
|
||||||
Normalization │ Raw ←────→ Normalized │
|
Normalization │ Raw ←────→ Normalized │
|
||||||
│ NormalizerProcessorStep (pre) │
|
│ NormalizerProcessorStep (pre) │
|
||||||
@@ -216,6 +227,10 @@ Here is how the different processors compose. Each arrow is a processor step, an
|
|||||||
|
|
||||||
A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`.
|
A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`.
|
||||||
|
|
||||||
|
With UMI-style relative proprioception (`use_relative_state=True`), the preprocessor also converts observation.state to offsets from the current timestep via `RelativeStateProcessorStep` before normalization. This is a pre-processing-only step (state is an input, not an output).
|
||||||
|
|
||||||
|
With `derive_state_from_action=True`, the preprocessor first runs `DeriveStateFromActionStep` to extract a 2-step state from the extended action chunk. This enables full UMI-style training without a separate `observation.state` column. See the [UMI pi0 guide](umi_pi0_relative_ee) for details.
|
||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
- [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions.
|
- [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions.
|
||||||
|
|||||||
@@ -4,16 +4,13 @@ This guide explains how to prepare a UMI-collected dataset for training a pi0 po
|
|||||||
|
|
||||||
**What we will do:**
|
**What we will do:**
|
||||||
|
|
||||||
1. How to add `observation.state` to an existing UMI LeRobot dataset.
|
1. Recompute dataset statistics for relative actions and state.
|
||||||
2. How to train pi0 with `use_relative_actions=True`.
|
2. Train pi0 with `derive_state_from_action=true` (full UMI pipeline).
|
||||||
3. How to evaluate the trained policy on a real robot.
|
3. Evaluate the trained policy on a real robot.
|
||||||
|
|
||||||
## Background
|
## Background
|
||||||
|
|
||||||
[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. UMI datasets stored in LeRobot format already contain `action` (absolute EE pose) and wrist-camera images. To train pi0 with relative actions, we need two additions:
|
[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. UMI datasets stored in LeRobot format already contain `action` (absolute EE pose) and wrist-camera images. To train pi0 with relative actions, we need **relative action statistics** — so the normalizer sees `(action − state)` distributions.
|
||||||
|
|
||||||
1. **`observation.state`** — the current EE pose the policy conditions on.
|
|
||||||
2. **Relative action statistics** — so the normalizer sees `(action − state)` distributions.
|
|
||||||
|
|
||||||
### Why relative actions?
|
### Why relative actions?
|
||||||
|
|
||||||
@@ -25,72 +22,39 @@ relative_action[i] = absolute_action[t + i] − state[t]
|
|||||||
|
|
||||||
This is the representation advocated by UMI (Chi et al., 2024). Compared to absolute actions it removes the need for a consistent global coordinate frame, and compared to delta actions (each step relative to the previous) it avoids error accumulation across the chunk. See the [Action Representations](action_representations) guide for a full comparison.
|
This is the representation advocated by UMI (Chi et al., 2024). Compared to absolute actions it removes the need for a consistent global coordinate frame, and compared to delta actions (each step relative to the previous) it avoids error accumulation across the chunk. See the [Action Representations](action_representations) guide for a full comparison.
|
||||||
|
|
||||||
## State-Action Offset
|
### Full UMI mode: `derive_state_from_action`
|
||||||
|
|
||||||
UMI SLAM produces a single trajectory of EE poses stored as `action`. We derive `observation.state` from the same trajectory with a configurable offset:
|
When `derive_state_from_action=true`, pi0 automatically derives `observation.state` from the action column on the fly — no separate state column or dataset conversion step needed. Under the hood:
|
||||||
|
|
||||||
```
|
- `action_delta_indices` extends to `[-1, 0, 1, ..., 49]` (one extra leading timestep).
|
||||||
state[t] = action[t - offset]
|
- `DeriveStateFromActionStep` extracts `[action[t-1], action[t]]` as a 2-step state and strips the extra timestep from the action chunk.
|
||||||
```
|
- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
|
||||||
|
- `RelativeStateProcessorStep` converts the 2-step state to relative proprioception (velocity + zeros) and flattens.
|
||||||
|
|
||||||
| Offset | `state[t]` | Meaning |
|
This single flag implies `use_relative_state=true` and `state_obs_steps=2`.
|
||||||
| ------ | ------------- | ---------------------------------------------------------------- |
|
|
||||||
| 0 | `action[t]` | State and action are the same pose at time t |
|
|
||||||
| 1 | `action[t-1]` | State is the previous action — where the gripper already arrived |
|
|
||||||
|
|
||||||
An offset of 1 is the typical UMI convention: at decision time the "current state" is where the gripper _already is_ (the result of the previous command), and the action is where it should go next. At episode boundaries where `t < offset`, we clamp to `action[0]`.
|
During **inference**, state comes from the robot (via FK), so `DeriveStateFromActionStep` is a no-op. `RelativeStateProcessorStep` buffers the previous state and applies the same conversion automatically.
|
||||||
|
|
||||||
## Step 1: Add `observation.state`
|
## Step 1: Recompute Stats
|
||||||
|
|
||||||
pi0 with `use_relative_actions=True` needs `observation.state` in the dataset to compute `action - state` on the fly. The script in `examples/umi_pi0_relative_ee/convert_umi_dataset.py` adds it. Edit the constants at the top:
|
Use the built-in CLI to recompute dataset statistics for relative actions and derive-state-from-action:
|
||||||
|
|
||||||
```python
|
|
||||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
|
||||||
|
|
||||||
# Option A: Copy an existing feature as observation.state
|
|
||||||
STATE_SOURCE_FEATURE = "observation.joints" # or "observation.pose", etc.
|
|
||||||
|
|
||||||
# Option B: Derive from action with offset (set STATE_SOURCE_FEATURE = None)
|
|
||||||
STATE_SOURCE_FEATURE = None
|
|
||||||
STATE_ACTION_OFFSET = 1
|
|
||||||
```
|
|
||||||
|
|
||||||
**Choosing the state source:**
|
|
||||||
|
|
||||||
- If your dataset already has a feature in the same space as `action` (e.g. `observation.joints` for joint-space actions, or `observation.pose` for EE-space actions), set `STATE_SOURCE_FEATURE` to copy it.
|
|
||||||
- If your dataset only has a single trajectory (like raw UMI EE data where action = the EE poses), set `STATE_SOURCE_FEATURE = None` and use `STATE_ACTION_OFFSET` to derive state from the action column with a time offset.
|
|
||||||
|
|
||||||
`observation.state` **must have the same shape as `action`** — the relative conversion computes `action - state` element-wise.
|
|
||||||
|
|
||||||
Then run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python examples/umi_pi0_relative_ee/convert_umi_dataset.py
|
|
||||||
```
|
|
||||||
|
|
||||||
<Tip>
|
|
||||||
|
|
||||||
If your dataset already has `observation.state`, the script exits early — nothing to do.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
## Step 2: Recompute Relative Action Stats
|
|
||||||
|
|
||||||
Use the built-in CLI to recompute dataset statistics in relative space:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
--repo_id <your_dataset> \
|
--repo_id <your_dataset> \
|
||||||
--operation.type recompute_stats \
|
--operation.type recompute_stats \
|
||||||
--operation.relative_action true \
|
--operation.relative_action true \
|
||||||
|
--operation.derive_state_from_action true \
|
||||||
--operation.chunk_size 50 \
|
--operation.chunk_size 50 \
|
||||||
--operation.relative_exclude_joints "['gripper']" \
|
--operation.relative_exclude_joints "['gripper']" \
|
||||||
--push_to_hub true
|
--push_to_hub true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The `derive_state_from_action` flag tells `recompute_stats` to read from the action column (instead of `observation.state`) when computing relative state stats. It automatically enables `relative_state=true` and `state_obs_steps=2`.
|
||||||
|
|
||||||
The `relative_exclude_joints` parameter specifies joints that stay absolute. Gripper commands are typically binary or continuous open/close and don't benefit from relative encoding. Leave it as `"[]"` to convert all dimensions to relative.
|
The `relative_exclude_joints` parameter specifies joints that stay absolute. Gripper commands are typically binary or continuous open/close and don't benefit from relative encoding. Leave it as `"[]"` to convert all dimensions to relative.
|
||||||
|
|
||||||
## Step 3: Train
|
## Step 2: Train
|
||||||
|
|
||||||
No custom training script is needed — standard `lerobot-train` handles everything:
|
No custom training script is needed — standard `lerobot-train` handles everything:
|
||||||
|
|
||||||
@@ -99,19 +63,26 @@ lerobot-train \
|
|||||||
--dataset.repo_id=<hf_username>/<dataset_repo_id> \
|
--dataset.repo_id=<hf_username>/<dataset_repo_id> \
|
||||||
--policy.type=pi0 \
|
--policy.type=pi0 \
|
||||||
--policy.pretrained_path=lerobot/pi0_base \
|
--policy.pretrained_path=lerobot/pi0_base \
|
||||||
|
--policy.derive_state_from_action=true \
|
||||||
--policy.use_relative_actions=true \
|
--policy.use_relative_actions=true \
|
||||||
--policy.relative_exclude_joints='["gripper"]'
|
--policy.relative_exclude_joints='["gripper"]'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`derive_state_from_action=true` auto-enables `use_relative_state=true` and `state_obs_steps=2`.
|
||||||
|
|
||||||
Under the hood, the training pipeline:
|
Under the hood, the training pipeline:
|
||||||
|
|
||||||
- Loads relative action stats from the dataset's `meta/stats.json`.
|
- Loads relative action stats and relative state stats from the dataset's `meta/stats.json`.
|
||||||
- Configures `RelativeActionsProcessorStep` in the preprocessor (absolute → relative before normalization).
|
- Extends `action_delta_indices` to `[-1, 0, 1, ..., 49]` to load one extra leading timestep.
|
||||||
- The model trains on normalized relative action values.
|
- `DeriveStateFromActionStep` extracts the 2-step state from the action chunk and strips the extra timestep.
|
||||||
|
- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
|
||||||
|
- `RelativeStateProcessorStep` converts the 2-step state to relative offsets from the current timestep, then flattens.
|
||||||
|
- `NormalizerProcessorStep` normalizes everything.
|
||||||
|
- The model trains on normalized relative values.
|
||||||
|
|
||||||
See the [pi0 documentation](pi0) for all available training options.
|
See the [pi0 documentation](pi0) for all available training options.
|
||||||
|
|
||||||
## Step 4: Evaluate
|
## Step 3: Evaluate
|
||||||
|
|
||||||
The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real robot (SO-100 with EE space):
|
The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real robot (SO-100 with EE space):
|
||||||
|
|
||||||
@@ -121,10 +92,20 @@ python examples/umi_pi0_relative_ee/evaluate.py
|
|||||||
|
|
||||||
Edit `HF_MODEL_ID`, `HF_DATASET_ID`, and robot configuration at the top of the file.
|
Edit `HF_MODEL_ID`, `HF_DATASET_ID`, and robot configuration at the top of the file.
|
||||||
|
|
||||||
|
### Latency compensation
|
||||||
|
|
||||||
|
For real robot deployment, you may want to skip the first few steps of each predicted action chunk to compensate for system latency. Set `LATENCY_SKIP_STEPS` in the evaluate script:
|
||||||
|
|
||||||
|
```python
|
||||||
|
LATENCY_SKIP_STEPS = 0 # ceil(total_latency_ms / (1000 / FPS))
|
||||||
|
```
|
||||||
|
|
||||||
|
For example, at 10Hz with ~200ms total latency, set `LATENCY_SKIP_STEPS = 2`.
|
||||||
|
|
||||||
The inference flow uses pi0's built-in processor pipeline — no custom wrappers needed:
|
The inference flow uses pi0's built-in processor pipeline — no custom wrappers needed:
|
||||||
|
|
||||||
1. **Robot → FK** — Joint positions are converted to EE pose via `ForwardKinematicsJointsToEE`, producing `observation.state`.
|
1. **Robot → FK** — Joint positions are converted to EE pose via `ForwardKinematicsJointsToEE`, producing `observation.state`.
|
||||||
2. **Preprocessor** — `RelativeActionsProcessorStep` caches the raw `observation.state`, then `NormalizerProcessorStep` normalizes everything.
|
2. **Preprocessor** — `DeriveStateFromActionStep` is a no-op (state comes from robot). `RelativeStateProcessorStep` buffers previous state, stacks, and converts to relative. `RelativeActionsProcessorStep` caches state. `NormalizerProcessorStep` normalizes.
|
||||||
3. **pi0 inference** — The model predicts a normalized relative action chunk.
|
3. **pi0 inference** — The model predicts a normalized relative action chunk.
|
||||||
4. **Postprocessor** — `UnnormalizerProcessorStep` unnormalizes, then `AbsoluteActionsProcessorStep` adds the cached state back to get absolute EE targets.
|
4. **Postprocessor** — `UnnormalizerProcessorStep` unnormalizes, then `AbsoluteActionsProcessorStep` adds the cached state back to get absolute EE targets.
|
||||||
5. **IK → Robot** — `InverseKinematicsEEToJoints` converts absolute EE targets to joint commands.
|
5. **IK → Robot** — `InverseKinematicsEEToJoints` converts absolute EE targets to joint commands.
|
||||||
@@ -132,14 +113,22 @@ The inference flow uses pi0's built-in processor pipeline — no custom wrappers
|
|||||||
## How the Pieces Fit Together
|
## How the Pieces Fit Together
|
||||||
|
|
||||||
```
|
```
|
||||||
Training:
|
Training (full UMI mode: derive_state_from_action=true):
|
||||||
dataset (absolute EE) → RelativeActionsProcessorStep → NormalizerProcessorStep → pi0 model
|
DataLoader (action: B,51,D)
|
||||||
|
→ DeriveStateFromActionStep (state = action[:,:2,:], action = action[:,1:,:])
|
||||||
|
→ RelativeActionsProcessorStep (action -= state[:,-1,:])
|
||||||
|
→ RelativeStateProcessorStep (state offsets from current, flatten → B,2*D)
|
||||||
|
→ NormalizerProcessorStep → pi0 model
|
||||||
|
|
||||||
Inference:
|
Inference:
|
||||||
robot joints → FK → observation.state (absolute EE)
|
robot joints → FK → observation.state (absolute EE)
|
||||||
↓
|
↓
|
||||||
|
DeriveStateFromActionStep (no-op)
|
||||||
|
↓
|
||||||
RelativeActionsProcessorStep (caches state)
|
RelativeActionsProcessorStep (caches state)
|
||||||
↓
|
↓
|
||||||
|
RelativeStateProcessorStep (buffers prev, stacks, subtracts, flattens)
|
||||||
|
↓
|
||||||
NormalizerProcessorStep → pi0 model → relative action chunk
|
NormalizerProcessorStep → pi0 model → relative action chunk
|
||||||
↓
|
↓
|
||||||
UnnormalizerProcessorStep
|
UnnormalizerProcessorStep
|
||||||
@@ -149,6 +138,31 @@ Inference:
|
|||||||
IK → joint targets → robot
|
IK → joint targets → robot
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Manual mode (without derive_state_from_action)
|
||||||
|
|
||||||
|
If your dataset already has `observation.state` (or you want to add it separately), you can skip `derive_state_from_action` and use relative actions + relative state independently:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Recompute stats
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id <your_dataset> \
|
||||||
|
--operation.type recompute_stats \
|
||||||
|
--operation.relative_action true \
|
||||||
|
--operation.relative_state true \
|
||||||
|
--operation.state_obs_steps 2 \
|
||||||
|
--operation.chunk_size 50 \
|
||||||
|
--operation.relative_exclude_joints "['gripper']"
|
||||||
|
|
||||||
|
# Train
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=<your_dataset> \
|
||||||
|
--policy.type=pi0 \
|
||||||
|
--policy.use_relative_actions=true \
|
||||||
|
--policy.use_relative_state=true \
|
||||||
|
--policy.state_obs_steps=2 \
|
||||||
|
--policy.relative_exclude_joints='["gripper"]'
|
||||||
|
```
|
||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
- [UMI: Universal Manipulation Interface](https://umi-gripper.github.io) — Chi et al., 2024. Defines relative trajectory actions.
|
- [UMI: Universal Manipulation Interface](https://umi-gripper.github.io) — Chi et al., 2024. Defines relative trajectory actions.
|
||||||
|
|||||||
@@ -1,220 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
Add ``observation.state`` to an existing LeRobot dataset.
|
|
||||||
|
|
||||||
pi0 uses ``observation.state`` as its proprioceptive input AND for
|
|
||||||
relative action conversion (action − state). This script creates
|
|
||||||
``observation.state`` by concatenating one or more existing features.
|
|
||||||
|
|
||||||
Ordering matters: the features whose dimensions correspond to ``action``
|
|
||||||
must come FIRST, because ``RelativeActionsProcessorStep`` subtracts
|
|
||||||
``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE
|
|
||||||
pose) are appended after and are seen by the model but not used for
|
|
||||||
relative conversion.
|
|
||||||
|
|
||||||
Example: action = [proximal, distal], and we want the model to also see
|
|
||||||
the EE pose:
|
|
||||||
|
|
||||||
STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"]
|
|
||||||
→ observation.state = [proximal, distal, x, y, z, ax, ay, az]
|
|
||||||
|
|
||||||
The relative conversion uses state[:2] = [proximal, distal] to subtract
|
|
||||||
from action[:2], and the model sees all 8 dimensions.
|
|
||||||
|
|
||||||
After running this script, recompute relative action stats:
|
|
||||||
|
|
||||||
lerobot-edit-dataset \\
|
|
||||||
--repo_id <your_dataset> \\
|
|
||||||
--operation.type recompute_stats \\
|
|
||||||
--operation.relative_action true \\
|
|
||||||
--operation.chunk_size 50 \\
|
|
||||||
--operation.relative_exclude_joints "[]" \\
|
|
||||||
--push_to_hub true
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python convert_umi_dataset.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from lerobot.datasets.dataset_tools import add_features
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
HF_DATASET_ID = ""
|
|
||||||
|
|
||||||
# Output repo ID. Set to None for default "<input>_modified".
|
|
||||||
OUTPUT_REPO_ID: str | None = None
|
|
||||||
|
|
||||||
# Features to concatenate into observation.state. Order matters:
|
|
||||||
# action-matching features FIRST, then extra proprioception.
|
|
||||||
# Set to a single string to copy one feature directly.
|
|
||||||
STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"]
|
|
||||||
|
|
||||||
# Only used when STATE_SOURCE_FEATURES is None:
|
|
||||||
# derive state from action with a per-episode offset.
|
|
||||||
STATE_ACTION_OFFSET = 1
|
|
||||||
|
|
||||||
# Push the augmented dataset to the Hugging Face Hub.
|
|
||||||
PUSH_TO_HUB = True
|
|
||||||
|
|
||||||
|
|
||||||
def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]:
|
|
||||||
hf = dataset.hf_dataset
|
|
||||||
ep = np.array(hf["episode_index"])
|
|
||||||
fr = np.array(hf["frame_index"])
|
|
||||||
return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))}
|
|
||||||
|
|
||||||
|
|
||||||
def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable:
|
|
||||||
"""Concatenate multiple features into observation.state."""
|
|
||||||
hf = dataset.hf_dataset
|
|
||||||
key_to_global = _build_global_index(dataset)
|
|
||||||
|
|
||||||
columns = [hf[feat] for feat in source_features]
|
|
||||||
|
|
||||||
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
|
|
||||||
g = key_to_global[(ep_idx, frame_idx)]
|
|
||||||
parts = []
|
|
||||||
for col in columns:
|
|
||||||
val = col[g]
|
|
||||||
if hasattr(val, "tolist"):
|
|
||||||
flat = val.tolist()
|
|
||||||
if isinstance(flat, list):
|
|
||||||
parts.extend(flat)
|
|
||||||
else:
|
|
||||||
parts.append(flat)
|
|
||||||
elif isinstance(val, list):
|
|
||||||
parts.extend(val)
|
|
||||||
else:
|
|
||||||
parts.append(float(val))
|
|
||||||
return parts
|
|
||||||
|
|
||||||
return _get_state
|
|
||||||
|
|
||||||
|
|
||||||
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
|
|
||||||
"""Derive state from action with a per-episode offset."""
|
|
||||||
hf = dataset.hf_dataset
|
|
||||||
episode_indices = np.array(hf["episode_index"])
|
|
||||||
frame_indices = np.array(hf["frame_index"])
|
|
||||||
|
|
||||||
ep_sorted: dict[int, list[tuple[int, int]]] = {}
|
|
||||||
for ep_idx in np.unique(episode_indices):
|
|
||||||
ep_mask = episode_indices == ep_idx
|
|
||||||
ep_globals = np.where(ep_mask)[0]
|
|
||||||
ep_frames = frame_indices[ep_globals]
|
|
||||||
order = np.argsort(ep_frames)
|
|
||||||
ep_sorted[int(ep_idx)] = [(int(ep_frames[o]), int(ep_globals[o])) for o in order]
|
|
||||||
|
|
||||||
ep_frame_to_local: dict[int, dict[int, int]] = {}
|
|
||||||
for ep_idx, sorted_list in ep_sorted.items():
|
|
||||||
ep_frame_to_local[ep_idx] = {frame: local for local, (frame, _) in enumerate(sorted_list)}
|
|
||||||
|
|
||||||
actions = hf["action"]
|
|
||||||
|
|
||||||
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
|
|
||||||
local_t = ep_frame_to_local[ep_idx][frame_idx]
|
|
||||||
source_local = max(0, local_t - offset)
|
|
||||||
_, source_global = ep_sorted[ep_idx][source_local]
|
|
||||||
return actions[source_global]
|
|
||||||
|
|
||||||
return _get_state
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
logger.info(f"Loading dataset {HF_DATASET_ID}")
|
|
||||||
dataset = LeRobotDataset(HF_DATASET_ID)
|
|
||||||
|
|
||||||
if "observation.state" in dataset.features:
|
|
||||||
logger.info("observation.state already exists — nothing to do")
|
|
||||||
return
|
|
||||||
|
|
||||||
action_meta = dataset.features["action"]
|
|
||||||
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
|
|
||||||
|
|
||||||
if STATE_SOURCE_FEATURES is not None:
|
|
||||||
source_list = (
|
|
||||||
[STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES)
|
|
||||||
)
|
|
||||||
for feat in source_list:
|
|
||||||
if feat not in dataset.features:
|
|
||||||
raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}")
|
|
||||||
|
|
||||||
# Compute combined shape and names
|
|
||||||
total_dim = 0
|
|
||||||
all_names = []
|
|
||||||
for feat in source_list:
|
|
||||||
meta = dataset.features[feat]
|
|
||||||
total_dim += meta["shape"][0]
|
|
||||||
names = meta.get("names")
|
|
||||||
if names:
|
|
||||||
all_names.extend(names)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})"
|
|
||||||
)
|
|
||||||
state_fn = _build_state_from_features(dataset, source_list)
|
|
||||||
state_feature_info = {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": [total_dim],
|
|
||||||
"names": all_names or None,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
|
|
||||||
state_fn = _build_state_from_action_offset(dataset, offset=STATE_ACTION_OFFSET)
|
|
||||||
state_feature_info = {
|
|
||||||
"dtype": "float32",
|
|
||||||
"shape": list(action_meta["shape"]),
|
|
||||||
"names": action_meta.get("names"),
|
|
||||||
}
|
|
||||||
|
|
||||||
augmented = add_features(
|
|
||||||
dataset,
|
|
||||||
features={"observation.state": (state_fn, state_feature_info)},
|
|
||||||
repo_id=OUTPUT_REPO_ID,
|
|
||||||
)
|
|
||||||
logger.info("observation.state added")
|
|
||||||
|
|
||||||
if PUSH_TO_HUB:
|
|
||||||
logger.info(f"Pushing to Hub: {augmented.repo_id}")
|
|
||||||
augmented.push_to_hub()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Done. Dataset at: {augmented.root}\n"
|
|
||||||
"Now recompute relative action stats:\n"
|
|
||||||
" lerobot-edit-dataset \\\n"
|
|
||||||
f" --repo_id {augmented.repo_id} \\\n"
|
|
||||||
" --operation.type recompute_stats \\\n"
|
|
||||||
" --operation.relative_action true \\\n"
|
|
||||||
" --operation.chunk_size 50 \\\n"
|
|
||||||
' --operation.relative_exclude_joints "[]" \\\n'
|
|
||||||
" --push_to_hub true"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -17,17 +17,20 @@
|
|||||||
"""
|
"""
|
||||||
Inference script for a pi0 model trained with **relative EE actions**.
|
Inference script for a pi0 model trained with **relative EE actions**.
|
||||||
|
|
||||||
This uses the built-in ``RelativeActionsProcessorStep`` and
|
This uses the built-in ``DeriveStateFromActionStep`` (no-op at inference),
|
||||||
``AbsoluteActionsProcessorStep`` that are already wired into pi0's
|
``RelativeActionsProcessorStep``, ``AbsoluteActionsProcessorStep``, and
|
||||||
processor pipeline when ``use_relative_actions=True``.
|
``RelativeStateProcessorStep`` that are already wired into pi0's processor
|
||||||
|
pipeline.
|
||||||
|
|
||||||
The inference loop:
|
The inference loop:
|
||||||
1. Reads joint positions from the robot.
|
1. Reads joint positions from the robot.
|
||||||
2. Converts to EE pose via forward kinematics (FK).
|
2. Converts to EE pose via forward kinematics (FK).
|
||||||
This produces ``observation.state`` with the current EE pose.
|
This produces ``observation.state`` with the current EE pose.
|
||||||
3. The pi0 preprocessor:
|
3. The pi0 preprocessor:
|
||||||
a) ``RelativeActionsProcessorStep`` caches the raw state.
|
a) ``DeriveStateFromActionStep`` — no-op (state comes from robot).
|
||||||
b) ``NormalizerProcessorStep`` normalizes state and actions.
|
b) ``RelativeActionsProcessorStep`` caches the raw state.
|
||||||
|
c) ``RelativeStateProcessorStep`` buffers prev state, stacks, subtracts.
|
||||||
|
d) ``NormalizerProcessorStep`` normalizes state and actions.
|
||||||
4. pi0 predicts relative action chunk.
|
4. pi0 predicts relative action chunk.
|
||||||
5. The pi0 postprocessor:
|
5. The pi0 postprocessor:
|
||||||
a) ``UnnormalizerProcessorStep`` unnormalizes.
|
a) ``UnnormalizerProcessorStep`` unnormalizes.
|
||||||
@@ -51,6 +54,7 @@ from lerobot.model.kinematics import RobotKinematics
|
|||||||
from lerobot.policies.factory import make_pre_post_processors
|
from lerobot.policies.factory import make_pre_post_processors
|
||||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
|
RelativeStateProcessorStep,
|
||||||
RobotProcessorPipeline,
|
RobotProcessorPipeline,
|
||||||
make_default_teleop_action_processor,
|
make_default_teleop_action_processor,
|
||||||
)
|
)
|
||||||
@@ -79,6 +83,11 @@ TASK_DESCRIPTION = "manipulation task"
|
|||||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||||
|
|
||||||
|
# Latency compensation: skip this many steps from the start of each predicted
|
||||||
|
# action chunk. Formula: ceil(total_latency_ms / (1000 / FPS)).
|
||||||
|
# E.g. at 10Hz with ~200ms total system latency: ceil(200 / 100) = 2.
|
||||||
|
LATENCY_SKIP_STEPS = 0
|
||||||
|
|
||||||
# EE feature keys produced by ForwardKinematicsJointsToEE
|
# EE feature keys produced by ForwardKinematicsJointsToEE
|
||||||
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||||
|
|
||||||
@@ -94,6 +103,7 @@ def main():
|
|||||||
robot = SO100Follower(robot_config)
|
robot = SO100Follower(robot_config)
|
||||||
|
|
||||||
policy = PI0Policy.from_pretrained(HF_MODEL_ID)
|
policy = PI0Policy.from_pretrained(HF_MODEL_ID)
|
||||||
|
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
|
||||||
|
|
||||||
kinematics_solver = RobotKinematics(
|
kinematics_solver = RobotKinematics(
|
||||||
urdf_path="./SO101/so101_new_calib.urdf",
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
@@ -151,9 +161,8 @@ def main():
|
|||||||
|
|
||||||
# Build pre/post processors from the trained model.
|
# Build pre/post processors from the trained model.
|
||||||
# The pi0 processor pipeline already includes:
|
# The pi0 processor pipeline already includes:
|
||||||
# pre: ... → RelativeActionsProcessorStep → NormalizerProcessorStep
|
# pre: ... → RelativeStateProcessorStep → RelativeActionsProcessorStep → NormalizerProcessorStep
|
||||||
# post: UnnormalizerProcessorStep → AbsoluteActionsProcessorStep → ...
|
# post: UnnormalizerProcessorStep → AbsoluteActionsProcessorStep → ...
|
||||||
# These handle the relative ↔ absolute conversion automatically.
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=policy,
|
policy_cfg=policy,
|
||||||
pretrained_path=HF_MODEL_ID,
|
pretrained_path=HF_MODEL_ID,
|
||||||
@@ -161,6 +170,9 @@ def main():
|
|||||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Find the relative state step (if present) so we can reset its buffer between episodes.
|
||||||
|
_relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
|
||||||
|
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener()
|
||||||
@@ -174,6 +186,10 @@ def main():
|
|||||||
for episode_idx in range(NUM_EPISODES):
|
for episode_idx in range(NUM_EPISODES):
|
||||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
|
# Reset relative state buffer so velocity is zero at episode start
|
||||||
|
for step in _relative_state_steps:
|
||||||
|
step.reset()
|
||||||
|
|
||||||
record_loop(
|
record_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
|
|||||||
@@ -115,6 +115,17 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
|||||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||||
|
"""Delta indices specifically for observation.state.
|
||||||
|
|
||||||
|
When not None, overrides ``observation_delta_indices`` for the
|
||||||
|
``observation.state`` key only. Useful for loading state history
|
||||||
|
(e.g. ``[-1, 0]`` for UMI-style relative proprioception) without
|
||||||
|
also loading multiple image timesteps.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -767,3 +767,94 @@ def compute_relative_action_stats(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def compute_relative_state_stats(
|
||||||
|
hf_dataset,
|
||||||
|
features: dict,
|
||||||
|
state_obs_steps: int = 2,
|
||||||
|
exclude_joints: list[str] | None = None,
|
||||||
|
source_key: str = OBS_STATE,
|
||||||
|
) -> dict[str, np.ndarray]:
|
||||||
|
"""Compute normalization statistics for observation.state after relative conversion.
|
||||||
|
|
||||||
|
For UMI-style relative proprioception with ``state_obs_steps`` timesteps,
|
||||||
|
each state observation becomes a stack of offsets from the current timestep:
|
||||||
|
``state[t-k] - state[t]`` for k in ``range(state_obs_steps-1, -1, -1)``.
|
||||||
|
|
||||||
|
The stats are computed over the flattened ``[state_obs_steps * state_dim]``
|
||||||
|
vector that the model actually sees after ``prepare_state`` flattening.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hf_dataset: The HuggingFace dataset with the source column and
|
||||||
|
"episode_index" columns.
|
||||||
|
features: Dataset feature metadata.
|
||||||
|
state_obs_steps: Number of observation timesteps (must be >= 2).
|
||||||
|
exclude_joints: State dimension names to keep absolute.
|
||||||
|
source_key: Column to read data from. Defaults to "observation.state".
|
||||||
|
When ``derive_state_from_action=True``, pass ``ACTION`` to read
|
||||||
|
from the action column instead.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99".
|
||||||
|
"""
|
||||||
|
from lerobot.processor.relative_action_processor import RelativeStateProcessorStep
|
||||||
|
|
||||||
|
if exclude_joints is None:
|
||||||
|
exclude_joints = []
|
||||||
|
|
||||||
|
state_dim = features[source_key]["shape"][0]
|
||||||
|
state_names = features.get(source_key, {}).get("names")
|
||||||
|
mask_step = RelativeStateProcessorStep(
|
||||||
|
enabled=True,
|
||||||
|
exclude_joints=exclude_joints,
|
||||||
|
state_names=state_names,
|
||||||
|
)
|
||||||
|
relative_mask = np.array(mask_step._build_mask(state_dim), dtype=np.float32)
|
||||||
|
|
||||||
|
logging.info(f"Loading data from '{source_key}' for relative state stats...")
|
||||||
|
all_states = np.array(hf_dataset[source_key], dtype=np.float32)
|
||||||
|
episode_indices = np.array(hf_dataset["episode_index"])
|
||||||
|
|
||||||
|
# Build all valid windows of length state_obs_steps within each episode
|
||||||
|
n = len(all_states)
|
||||||
|
if n < state_obs_steps:
|
||||||
|
raise ValueError(f"Dataset has {n} frames but state_obs_steps={state_obs_steps}")
|
||||||
|
|
||||||
|
max_start = n - state_obs_steps
|
||||||
|
starts = np.arange(max_start + 1)
|
||||||
|
valid = episode_indices[starts] == episode_indices[starts + state_obs_steps - 1]
|
||||||
|
valid_starts = starts[valid]
|
||||||
|
|
||||||
|
if len(valid_starts) == 0:
|
||||||
|
raise RuntimeError("No valid state windows found within single episodes")
|
||||||
|
|
||||||
|
offsets = np.arange(state_obs_steps)
|
||||||
|
mask_dim = len(relative_mask)
|
||||||
|
|
||||||
|
running_stats = RunningQuantileStats()
|
||||||
|
|
||||||
|
batch_size = 50_000
|
||||||
|
for i in range(0, len(valid_starts), batch_size):
|
||||||
|
batch_starts = valid_starts[i : i + batch_size]
|
||||||
|
frame_idx = batch_starts[:, None] + offsets[None, :] # [N, state_obs_steps]
|
||||||
|
windows = all_states[frame_idx].copy() # [N, state_obs_steps, state_dim]
|
||||||
|
|
||||||
|
# Subtract current (last) timestep from all timesteps for masked dims
|
||||||
|
current = windows[:, -1:, :] # [N, 1, state_dim]
|
||||||
|
windows[:, :, :mask_dim] -= current[:, :, :mask_dim] * relative_mask[None, None, :]
|
||||||
|
|
||||||
|
# Flatten to [N, state_obs_steps * state_dim] (same as prepare_state)
|
||||||
|
flattened = windows.reshape(len(batch_starts), -1)
|
||||||
|
running_stats.update(flattened)
|
||||||
|
|
||||||
|
stats = running_stats.get_statistics()
|
||||||
|
|
||||||
|
excluded_dims = int(mask_dim - relative_mask.sum())
|
||||||
|
logging.info(
|
||||||
|
f"Relative state stats ({len(valid_starts)} windows, obs_steps={state_obs_steps}): "
|
||||||
|
f"relative_dims={int(relative_mask.sum())}/{mask_dim} (excluded={excluded_dims}), "
|
||||||
|
f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from lerobot.datasets.compute_stats import (
|
|||||||
aggregate_stats,
|
aggregate_stats,
|
||||||
compute_episode_stats,
|
compute_episode_stats,
|
||||||
compute_relative_action_stats,
|
compute_relative_action_stats,
|
||||||
|
compute_relative_state_stats,
|
||||||
)
|
)
|
||||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||||
from lerobot.datasets.io_utils import (
|
from lerobot.datasets.io_utils import (
|
||||||
@@ -1544,6 +1545,10 @@ def recompute_stats(
|
|||||||
relative_exclude_joints: list[str] | None = None,
|
relative_exclude_joints: list[str] | None = None,
|
||||||
chunk_size: int = 50,
|
chunk_size: int = 50,
|
||||||
num_workers: int = 0,
|
num_workers: int = 0,
|
||||||
|
relative_state: bool = False,
|
||||||
|
relative_exclude_state_joints: list[str] | None = None,
|
||||||
|
state_obs_steps: int = 2,
|
||||||
|
derive_state_from_action: bool = False,
|
||||||
) -> LeRobotDataset:
|
) -> LeRobotDataset:
|
||||||
"""Recompute stats.json from scratch by iterating all episodes.
|
"""Recompute stats.json from scratch by iterating all episodes.
|
||||||
|
|
||||||
@@ -1561,10 +1566,22 @@ def recompute_stats(
|
|||||||
``policy.chunk_size``. Only used when ``relative_action=True``.
|
``policy.chunk_size``. Only used when ``relative_action=True``.
|
||||||
num_workers: Number of parallel threads for relative action stats computation.
|
num_workers: Number of parallel threads for relative action stats computation.
|
||||||
Values ≤1 mean single-threaded. Only used when ``relative_action=True``.
|
Values ≤1 mean single-threaded. Only used when ``relative_action=True``.
|
||||||
|
relative_state: If True, compute observation.state stats in relative space
|
||||||
|
(multi-timestep offsets from current). This matches the normalization
|
||||||
|
the model sees during training with ``use_relative_state=True``.
|
||||||
|
relative_exclude_state_joints: State dim names to exclude from relative conversion.
|
||||||
|
state_obs_steps: Number of observation timesteps for relative state stats.
|
||||||
|
Should match ``policy.state_obs_steps``. Only used when ``relative_state=True``.
|
||||||
|
derive_state_from_action: If True, compute relative state stats from the
|
||||||
|
action column instead of observation.state. Implies ``relative_state=True``
|
||||||
|
and ``state_obs_steps=2``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The same dataset with updated stats.
|
The same dataset with updated stats.
|
||||||
"""
|
"""
|
||||||
|
if derive_state_from_action:
|
||||||
|
relative_state = True
|
||||||
|
state_obs_steps = 2
|
||||||
features = dataset.meta.features
|
features = dataset.meta.features
|
||||||
meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"}
|
meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"}
|
||||||
numeric_features = {
|
numeric_features = {
|
||||||
@@ -1596,6 +1613,20 @@ def recompute_stats(
|
|||||||
)
|
)
|
||||||
features_to_compute.pop(ACTION, None)
|
features_to_compute.pop(ACTION, None)
|
||||||
|
|
||||||
|
# When relative_state is enabled, compute state stats over the flattened
|
||||||
|
# multi-timestep relative representation (matching what the model sees).
|
||||||
|
relative_state_stats = None
|
||||||
|
if relative_state and (OBS_STATE in features or derive_state_from_action):
|
||||||
|
source_key = ACTION if derive_state_from_action else OBS_STATE
|
||||||
|
relative_state_stats = compute_relative_state_stats(
|
||||||
|
hf_dataset=dataset.hf_dataset,
|
||||||
|
features=features,
|
||||||
|
state_obs_steps=state_obs_steps,
|
||||||
|
exclude_joints=relative_exclude_state_joints,
|
||||||
|
source_key=source_key,
|
||||||
|
)
|
||||||
|
features_to_compute.pop(OBS_STATE, None)
|
||||||
|
|
||||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||||
|
|
||||||
data_dir = dataset.root / DATA_DIR
|
data_dir = dataset.root / DATA_DIR
|
||||||
@@ -1632,6 +1663,9 @@ def recompute_stats(
|
|||||||
if relative_action_stats is not None:
|
if relative_action_stats is not None:
|
||||||
new_stats[ACTION] = relative_action_stats
|
new_stats[ACTION] = relative_action_stats
|
||||||
|
|
||||||
|
if relative_state_stats is not None:
|
||||||
|
new_stats[OBS_STATE] = relative_state_stats
|
||||||
|
|
||||||
# Merge: keep existing stats for features we didn't recompute
|
# Merge: keep existing stats for features we didn't recompute
|
||||||
if dataset.meta.stats:
|
if dataset.meta.stats:
|
||||||
for key, value in dataset.meta.stats.items():
|
for key, value in dataset.meta.stats.items():
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|||||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||||
from lerobot.datasets.transforms import ImageTransforms
|
from lerobot.datasets.transforms import ImageTransforms
|
||||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
from lerobot.utils.constants import ACTION, OBS_PREFIX, OBS_STATE, REWARD
|
||||||
|
|
||||||
IMAGENET_STATS = {
|
IMAGENET_STATS = {
|
||||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||||
@@ -52,12 +52,15 @@ def resolve_delta_timestamps(
|
|||||||
returns `None` if the resulting dict is empty.
|
returns `None` if the resulting dict is empty.
|
||||||
"""
|
"""
|
||||||
delta_timestamps = {}
|
delta_timestamps = {}
|
||||||
|
state_delta = getattr(cfg, "state_delta_indices", None)
|
||||||
for key in ds_meta.features:
|
for key in ds_meta.features:
|
||||||
if key == REWARD and cfg.reward_delta_indices is not None:
|
if key == REWARD and cfg.reward_delta_indices is not None:
|
||||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||||
if key == ACTION and cfg.action_delta_indices is not None:
|
if key == ACTION and cfg.action_delta_indices is not None:
|
||||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||||
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
if key == OBS_STATE and state_delta is not None:
|
||||||
|
delta_timestamps[key] = [i / ds_meta.fps for i in state_delta]
|
||||||
|
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||||
|
|
||||||
if len(delta_timestamps) == 0:
|
if len(delta_timestamps) == 0:
|
||||||
|
|||||||
@@ -57,6 +57,28 @@ class PI0Config(PreTrainedConfig):
|
|||||||
# Populated at runtime from dataset metadata by make_policy.
|
# Populated at runtime from dataset metadata by make_policy.
|
||||||
action_feature_names: list[str] | None = None
|
action_feature_names: list[str] | None = None
|
||||||
|
|
||||||
|
# Relative state (UMI-style relative proprioception): converts multi-timestep
|
||||||
|
# observation.state to offsets from the current timestep, providing velocity info.
|
||||||
|
# Requires state_obs_steps >= 2. The flattened multi-timestep state is padded to
|
||||||
|
# max_state_dim, so ensure state_obs_steps * state_dim <= max_state_dim.
|
||||||
|
use_relative_state: bool = False
|
||||||
|
state_obs_steps: int = 1
|
||||||
|
relative_exclude_state_joints: list[str] = field(default_factory=list)
|
||||||
|
# Populated at runtime from dataset metadata by make_policy.
|
||||||
|
state_feature_names: list[str] | None = None
|
||||||
|
|
||||||
|
# Derive observation.state from the action column (UMI-style).
|
||||||
|
# When True, action_delta_indices loads one extra leading timestep [-1, 0, ..., chunk_size-1],
|
||||||
|
# DeriveStateFromActionStep extracts [action[t-1], action[t]] as a 2-step state,
|
||||||
|
# and strips the extra timestep from the action chunk.
|
||||||
|
# Implies use_relative_state=True and state_obs_steps=2.
|
||||||
|
derive_state_from_action: bool = False
|
||||||
|
|
||||||
|
# Latency compensation: skip this many steps from the start of each predicted
|
||||||
|
# action chunk during inference. E.g. at 10Hz with ~200ms total latency,
|
||||||
|
# latency_skip_steps=2 compensates for the delay.
|
||||||
|
latency_skip_steps: int = 0
|
||||||
|
|
||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
@@ -106,6 +128,10 @@ class PI0Config(PreTrainedConfig):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
|
if self.derive_state_from_action:
|
||||||
|
self.use_relative_state = True
|
||||||
|
self.state_obs_steps = 2
|
||||||
|
|
||||||
# Validate configuration
|
# Validate configuration
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -121,6 +147,13 @@ class PI0Config(PreTrainedConfig):
|
|||||||
if self.dtype not in ["bfloat16", "float32"]:
|
if self.dtype not in ["bfloat16", "float32"]:
|
||||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||||
|
|
||||||
|
if self.use_relative_state and self.state_obs_steps < 2:
|
||||||
|
raise ValueError(
|
||||||
|
"use_relative_state requires state_obs_steps >= 2 "
|
||||||
|
f"(got {self.state_obs_steps}). Set state_obs_steps=2 for "
|
||||||
|
"UMI-style relative proprioception."
|
||||||
|
)
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate and set up input/output features."""
|
"""Validate and set up input/output features."""
|
||||||
for i in range(self.empty_cameras):
|
for i in range(self.empty_cameras):
|
||||||
@@ -166,8 +199,16 @@ class PI0Config(PreTrainedConfig):
|
|||||||
def observation_delta_indices(self) -> None:
|
def observation_delta_indices(self) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_delta_indices(self) -> list[int] | None:
|
||||||
|
if self.state_obs_steps >= 2:
|
||||||
|
return list(range(-(self.state_obs_steps - 1), 1))
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list:
|
def action_delta_indices(self) -> list:
|
||||||
|
if self.derive_state_from_action:
|
||||||
|
return [-1] + list(range(self.chunk_size))
|
||||||
return list(range(self.chunk_size))
|
return list(range(self.chunk_size))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -1230,8 +1230,11 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
return images, img_masks
|
return images, img_masks
|
||||||
|
|
||||||
def prepare_state(self, batch):
|
def prepare_state(self, batch):
|
||||||
"""Pad state"""
|
"""Flatten multi-timestep state and pad to max_state_dim."""
|
||||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
state = batch[OBS_STATE]
|
||||||
|
if state.ndim == 3:
|
||||||
|
state = state.flatten(start_dim=1)
|
||||||
|
state = pad_vector(state, self.config.max_state_dim)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def prepare_action(self, batch):
|
def prepare_action(self, batch):
|
||||||
@@ -1250,7 +1253,8 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Action queue logic for n_action_steps > 1
|
# Action queue logic for n_action_steps > 1
|
||||||
if len(self._action_queue) == 0:
|
if len(self._action_queue) == 0:
|
||||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
skip = self.config.latency_skip_steps
|
||||||
|
actions = self.predict_action_chunk(batch)[:, skip : skip + self.config.n_action_steps]
|
||||||
# Transpose to get shape (n_action_steps, batch_size, action_dim)
|
# Transpose to get shape (n_action_steps, batch_size, action_dim)
|
||||||
self._action_queue.extend(actions.transpose(0, 1))
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from lerobot.processor import (
|
|||||||
AbsoluteActionsProcessorStep,
|
AbsoluteActionsProcessorStep,
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
ComplementaryDataProcessorStep,
|
ComplementaryDataProcessorStep,
|
||||||
|
DeriveStateFromActionStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
PolicyAction,
|
PolicyAction,
|
||||||
@@ -31,6 +32,7 @@ from lerobot.processor import (
|
|||||||
ProcessorStep,
|
ProcessorStep,
|
||||||
ProcessorStepRegistry,
|
ProcessorStepRegistry,
|
||||||
RelativeActionsProcessorStep,
|
RelativeActionsProcessorStep,
|
||||||
|
RelativeStateProcessorStep,
|
||||||
RenameObservationsProcessorStep,
|
RenameObservationsProcessorStep,
|
||||||
TokenizerProcessorStep,
|
TokenizerProcessorStep,
|
||||||
UnnormalizerProcessorStep,
|
UnnormalizerProcessorStep,
|
||||||
@@ -128,13 +130,25 @@ def make_pi0_pre_post_processors(
|
|||||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
derive_state_step = DeriveStateFromActionStep(
|
||||||
|
enabled=getattr(config, "derive_state_from_action", False),
|
||||||
|
)
|
||||||
|
|
||||||
relative_step = RelativeActionsProcessorStep(
|
relative_step = RelativeActionsProcessorStep(
|
||||||
enabled=config.use_relative_actions,
|
enabled=config.use_relative_actions,
|
||||||
exclude_joints=getattr(config, "relative_exclude_joints", []),
|
exclude_joints=getattr(config, "relative_exclude_joints", []),
|
||||||
action_names=getattr(config, "action_feature_names", None),
|
action_names=getattr(config, "action_feature_names", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
# OpenPI order: raw → relative → normalize → model → unnormalize → absolute
|
relative_state_step = RelativeStateProcessorStep(
|
||||||
|
enabled=getattr(config, "use_relative_state", False),
|
||||||
|
exclude_joints=getattr(config, "relative_exclude_state_joints", []),
|
||||||
|
state_names=getattr(config, "state_feature_names", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Order: DeriveStateFromAction extracts state from the extended action chunk,
|
||||||
|
# then relative_action uses current state[t] for subtraction,
|
||||||
|
# then relative_state converts the multi-timestep state to offsets.
|
||||||
input_steps: list[ProcessorStep] = [
|
input_steps: list[ProcessorStep] = [
|
||||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||||
AddBatchDimensionProcessorStep(),
|
AddBatchDimensionProcessorStep(),
|
||||||
@@ -146,7 +160,9 @@ def make_pi0_pre_post_processors(
|
|||||||
padding="max_length",
|
padding="max_length",
|
||||||
),
|
),
|
||||||
DeviceProcessorStep(device=config.device),
|
DeviceProcessorStep(device=config.device),
|
||||||
|
derive_state_step,
|
||||||
relative_step,
|
relative_step,
|
||||||
|
relative_state_step,
|
||||||
NormalizerProcessorStep(
|
NormalizerProcessorStep(
|
||||||
features={**config.input_features, **config.output_features},
|
features={**config.input_features, **config.output_features},
|
||||||
norm_map=config.normalization_mapping,
|
norm_map=config.normalization_mapping,
|
||||||
|
|||||||
@@ -77,9 +77,12 @@ from .policy_robot_bridge import (
|
|||||||
)
|
)
|
||||||
from .relative_action_processor import (
|
from .relative_action_processor import (
|
||||||
AbsoluteActionsProcessorStep,
|
AbsoluteActionsProcessorStep,
|
||||||
|
DeriveStateFromActionStep,
|
||||||
RelativeActionsProcessorStep,
|
RelativeActionsProcessorStep,
|
||||||
|
RelativeStateProcessorStep,
|
||||||
to_absolute_actions,
|
to_absolute_actions,
|
||||||
to_relative_actions,
|
to_relative_actions,
|
||||||
|
to_relative_state,
|
||||||
)
|
)
|
||||||
from .rename_processor import RenameObservationsProcessorStep
|
from .rename_processor import RenameObservationsProcessorStep
|
||||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||||
@@ -107,7 +110,9 @@ __all__ = [
|
|||||||
"make_default_robot_action_processor",
|
"make_default_robot_action_processor",
|
||||||
"make_default_robot_observation_processor",
|
"make_default_robot_observation_processor",
|
||||||
"AbsoluteActionsProcessorStep",
|
"AbsoluteActionsProcessorStep",
|
||||||
|
"DeriveStateFromActionStep",
|
||||||
"RelativeActionsProcessorStep",
|
"RelativeActionsProcessorStep",
|
||||||
|
"RelativeStateProcessorStep",
|
||||||
"MapDeltaActionToRobotActionStep",
|
"MapDeltaActionToRobotActionStep",
|
||||||
"MapTensorToDeltaActionDictStep",
|
"MapTensorToDeltaActionDictStep",
|
||||||
"NormalizerProcessorStep",
|
"NormalizerProcessorStep",
|
||||||
@@ -139,6 +144,7 @@ __all__ = [
|
|||||||
"TruncatedProcessorStep",
|
"TruncatedProcessorStep",
|
||||||
"to_absolute_actions",
|
"to_absolute_actions",
|
||||||
"to_relative_actions",
|
"to_relative_actions",
|
||||||
|
"to_relative_state",
|
||||||
"UnnormalizerProcessorStep",
|
"UnnormalizerProcessorStep",
|
||||||
"VanillaObservationProcessorStep",
|
"VanillaObservationProcessorStep",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -30,10 +30,13 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"MapDeltaActionToRobotActionStep",
|
"MapDeltaActionToRobotActionStep",
|
||||||
"MapTensorToDeltaActionDictStep",
|
"MapTensorToDeltaActionDictStep",
|
||||||
|
"DeriveStateFromActionStep",
|
||||||
"RelativeActionsProcessorStep",
|
"RelativeActionsProcessorStep",
|
||||||
"AbsoluteActionsProcessorStep",
|
"AbsoluteActionsProcessorStep",
|
||||||
|
"RelativeStateProcessorStep",
|
||||||
"to_relative_actions",
|
"to_relative_actions",
|
||||||
"to_absolute_actions",
|
"to_absolute_actions",
|
||||||
|
"to_relative_state",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -81,6 +84,41 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
|
|||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("derive_state_from_action_processor")
|
||||||
|
@dataclass
|
||||||
|
class DeriveStateFromActionStep(ProcessorStep):
|
||||||
|
"""Derives 2-step observation.state from the action chunk (UMI-style).
|
||||||
|
|
||||||
|
Expects action with one extra leading timestep: [B, chunk_size+1, D]
|
||||||
|
from action_delta_indices = [-1, 0, 1, ..., chunk_size-1].
|
||||||
|
Extracts [action[t-1], action[t]] as state and strips the extra timestep.
|
||||||
|
No-op during inference (state comes from robot).
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
if not self.enabled:
|
||||||
|
return transition
|
||||||
|
action = transition.get(TransitionKey.ACTION)
|
||||||
|
if action is None or action.ndim < 3:
|
||||||
|
return transition
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
|
||||||
|
new_obs[OBS_STATE] = action[..., :2, :]
|
||||||
|
new_transition[TransitionKey.ACTION] = action[..., 1:, :]
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = new_obs
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {"enabled": self.enabled}
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||||
@dataclass
|
@dataclass
|
||||||
class RelativeActionsProcessorStep(ProcessorStep):
|
class RelativeActionsProcessorStep(ProcessorStep):
|
||||||
@@ -124,7 +162,14 @@ class RelativeActionsProcessorStep(ProcessorStep):
|
|||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||||
state = observation.get(OBS_STATE) if observation else None
|
raw_state = observation.get(OBS_STATE) if observation else None
|
||||||
|
|
||||||
|
# When state_delta_indices loads multi-timestep state [B, n_obs, D],
|
||||||
|
# use only the current (last) timestep for relative action conversion.
|
||||||
|
if raw_state is not None:
|
||||||
|
state = raw_state[..., -1, :] if raw_state.ndim >= 3 else raw_state
|
||||||
|
else:
|
||||||
|
state = None
|
||||||
|
|
||||||
# Always cache state for the paired AbsoluteActionsProcessorStep
|
# Always cache state for the paired AbsoluteActionsProcessorStep
|
||||||
if state is not None:
|
if state is not None:
|
||||||
@@ -155,6 +200,120 @@ class RelativeActionsProcessorStep(ProcessorStep):
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def to_relative_state(state: Tensor, mask: Sequence[bool]) -> Tensor:
|
||||||
|
"""Convert multi-timestep absolute state to relative (offset from current timestep).
|
||||||
|
|
||||||
|
Each timestep becomes: ``state[..., t, :] - state[..., -1, :]`` for masked dims.
|
||||||
|
The last (current) timestep becomes zeros for masked dims.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: (..., n_obs, state_dim) — last timestep is the reference (current).
|
||||||
|
mask: Which dims to convert. Can be shorter than state_dim.
|
||||||
|
"""
|
||||||
|
mask_t = torch.tensor(mask, dtype=state.dtype, device=state.device)
|
||||||
|
dims = mask_t.shape[0]
|
||||||
|
current = state[..., -1:, :] # (..., 1, state_dim)
|
||||||
|
state = state.clone()
|
||||||
|
state[..., :dims] -= current[..., :dims] * mask_t
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("relative_state_processor")
|
||||||
|
@dataclass
|
||||||
|
class RelativeStateProcessorStep(ProcessorStep):
|
||||||
|
"""Converts observation.state to relative (offset from current timestep).
|
||||||
|
|
||||||
|
UMI-style relative proprioception: each state timestep is expressed as
|
||||||
|
an offset from the current EE pose, providing velocity information.
|
||||||
|
|
||||||
|
During training (multi-timestep input from ``state_delta_indices``):
|
||||||
|
``state[..., t, :] -= state[..., -1, :]`` — subtract current from all.
|
||||||
|
|
||||||
|
During inference (single timestep): buffers the previous state and stacks
|
||||||
|
``[previous, current]`` before applying the relative conversion, producing
|
||||||
|
the same ``[n_obs, D]`` shape the model expects.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
enabled: Whether to apply the relative conversion.
|
||||||
|
exclude_joints: Joint/dim names to keep absolute.
|
||||||
|
state_names: State dimension names from dataset metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
exclude_joints: list[str] = field(default_factory=list)
|
||||||
|
state_names: list[str] | None = None
|
||||||
|
_previous_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
|
def _build_mask(self, state_dim: int) -> list[bool]:
|
||||||
|
if not self.exclude_joints or self.state_names is None:
|
||||||
|
return [True] * state_dim
|
||||||
|
|
||||||
|
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||||
|
if not exclude_tokens:
|
||||||
|
return [True] * state_dim
|
||||||
|
|
||||||
|
mask = []
|
||||||
|
for name in self.state_names[:state_dim]:
|
||||||
|
state_name = str(name).lower()
|
||||||
|
is_excluded = any(token == state_name or token in state_name for token in exclude_tokens)
|
||||||
|
mask.append(not is_excluded)
|
||||||
|
|
||||||
|
if len(mask) < state_dim:
|
||||||
|
mask.extend([True] * (state_dim - len(mask)))
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
if not self.enabled:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||||
|
state = observation.get(OBS_STATE) if observation else None
|
||||||
|
|
||||||
|
if state is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
|
||||||
|
mask = self._build_mask(state.shape[-1])
|
||||||
|
|
||||||
|
if state.ndim >= 3:
|
||||||
|
# [B, n_obs, D] — multi-timestep (training with state_delta_indices)
|
||||||
|
relative = to_relative_state(state, mask)
|
||||||
|
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, n_obs*D]
|
||||||
|
elif state.ndim == 2:
|
||||||
|
# [B, D] — single timestep (inference): buffer previous and stack
|
||||||
|
current = state
|
||||||
|
if self._previous_state is None:
|
||||||
|
self._previous_state = current.clone()
|
||||||
|
prev = self._previous_state
|
||||||
|
if prev.device != current.device or prev.dtype != current.dtype:
|
||||||
|
prev = prev.to(device=current.device, dtype=current.dtype)
|
||||||
|
stacked = torch.stack([prev, current], dim=-2) # [B, 2, D]
|
||||||
|
relative = to_relative_state(stacked, mask)
|
||||||
|
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, 2*D]
|
||||||
|
self._previous_state = current.clone()
|
||||||
|
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = new_obs
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the state buffer. Call at episode boundaries during inference."""
|
||||||
|
self._previous_state = None
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"enabled": self.enabled,
|
||||||
|
"exclude_joints": self.exclude_joints,
|
||||||
|
"state_names": self.state_names,
|
||||||
|
}
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("absolute_actions_processor")
|
@ProcessorStepRegistry.register("absolute_actions_processor")
|
||||||
@dataclass
|
@dataclass
|
||||||
class AbsoluteActionsProcessorStep(ProcessorStep):
|
class AbsoluteActionsProcessorStep(ProcessorStep):
|
||||||
|
|||||||
@@ -254,6 +254,10 @@ class RecomputeStatsConfig(OperationConfig):
|
|||||||
relative_exclude_joints: list[str] | None = None
|
relative_exclude_joints: list[str] | None = None
|
||||||
chunk_size: int = 50
|
chunk_size: int = 50
|
||||||
num_workers: int = 0
|
num_workers: int = 0
|
||||||
|
relative_state: bool = False
|
||||||
|
relative_exclude_state_joints: list[str] | None = None
|
||||||
|
state_obs_steps: int = 2
|
||||||
|
derive_state_from_action: bool = False
|
||||||
|
|
||||||
|
|
||||||
@OperationConfig.register_subclass("info")
|
@OperationConfig.register_subclass("info")
|
||||||
@@ -563,6 +567,14 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
|||||||
f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, "
|
f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, "
|
||||||
f"exclude_joints={cfg.operation.relative_exclude_joints})"
|
f"exclude_joints={cfg.operation.relative_exclude_joints})"
|
||||||
)
|
)
|
||||||
|
if cfg.operation.relative_state:
|
||||||
|
logging.info(
|
||||||
|
f"Relative state stats enabled (state_obs_steps={cfg.operation.state_obs_steps}, "
|
||||||
|
f"exclude_state_joints={cfg.operation.relative_exclude_state_joints})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.operation.derive_state_from_action:
|
||||||
|
logging.info("Derive state from action enabled (implies relative_state=True, state_obs_steps=2)")
|
||||||
|
|
||||||
recompute_stats(
|
recompute_stats(
|
||||||
dataset,
|
dataset,
|
||||||
@@ -571,6 +583,10 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
|||||||
relative_exclude_joints=cfg.operation.relative_exclude_joints,
|
relative_exclude_joints=cfg.operation.relative_exclude_joints,
|
||||||
chunk_size=cfg.operation.chunk_size,
|
chunk_size=cfg.operation.chunk_size,
|
||||||
num_workers=cfg.operation.num_workers,
|
num_workers=cfg.operation.num_workers,
|
||||||
|
relative_state=cfg.operation.relative_state,
|
||||||
|
relative_exclude_state_joints=cfg.operation.relative_exclude_state_joints,
|
||||||
|
state_obs_steps=cfg.operation.state_obs_steps,
|
||||||
|
derive_state_from_action=cfg.operation.derive_state_from_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info(f"Stats written to {dataset.root}")
|
logging.info(f"Stats written to {dataset.root}")
|
||||||
|
|||||||
Reference in New Issue
Block a user