refactor to use relative state

This commit is contained in:
Pepijn
2026-04-01 17:23:58 +02:00
parent 0fc855df13
commit 58bd11caf3
14 changed files with 502 additions and 296 deletions
@@ -1,220 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Add ``observation.state`` to an existing LeRobot dataset.
pi0 uses ``observation.state`` as its proprioceptive input AND for
relative action conversion (action state). This script creates
``observation.state`` by concatenating one or more existing features.
Ordering matters: the features whose dimensions correspond to ``action``
must come FIRST, because ``RelativeActionsProcessorStep`` subtracts
``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE
pose) are appended after and are seen by the model but not used for
relative conversion.
Example: action = [proximal, distal], and we want the model to also see
the EE pose:
STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"]
→ observation.state = [proximal, distal, x, y, z, ax, ay, az]
The relative conversion uses state[:2] = [proximal, distal] to subtract
from action[:2], and the model sees all 8 dimensions.
After running this script, recompute relative action stats:
lerobot-edit-dataset \\
--repo_id <your_dataset> \\
--operation.type recompute_stats \\
--operation.relative_action true \\
--operation.chunk_size 50 \\
--operation.relative_exclude_joints "[]" \\
--push_to_hub true
Usage:
python convert_umi_dataset.py
"""
from __future__ import annotations
import logging
from collections.abc import Callable
import numpy as np
from lerobot.datasets.dataset_tools import add_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
HF_DATASET_ID = ""
# Output repo ID. Set to None for default "<input>_modified".
OUTPUT_REPO_ID: str | None = None
# Features to concatenate into observation.state. Order matters:
# action-matching features FIRST, then extra proprioception.
# Set to a single string to copy one feature directly.
STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"]
# Only used when STATE_SOURCE_FEATURES is None:
# derive state from action with a per-episode offset.
STATE_ACTION_OFFSET = 1
# Push the augmented dataset to the Hugging Face Hub.
PUSH_TO_HUB = True
def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]:
hf = dataset.hf_dataset
ep = np.array(hf["episode_index"])
fr = np.array(hf["frame_index"])
return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))}
def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable:
"""Concatenate multiple features into observation.state."""
hf = dataset.hf_dataset
key_to_global = _build_global_index(dataset)
columns = [hf[feat] for feat in source_features]
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
g = key_to_global[(ep_idx, frame_idx)]
parts = []
for col in columns:
val = col[g]
if hasattr(val, "tolist"):
flat = val.tolist()
if isinstance(flat, list):
parts.extend(flat)
else:
parts.append(flat)
elif isinstance(val, list):
parts.extend(val)
else:
parts.append(float(val))
return parts
return _get_state
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
"""Derive state from action with a per-episode offset."""
hf = dataset.hf_dataset
episode_indices = np.array(hf["episode_index"])
frame_indices = np.array(hf["frame_index"])
ep_sorted: dict[int, list[tuple[int, int]]] = {}
for ep_idx in np.unique(episode_indices):
ep_mask = episode_indices == ep_idx
ep_globals = np.where(ep_mask)[0]
ep_frames = frame_indices[ep_globals]
order = np.argsort(ep_frames)
ep_sorted[int(ep_idx)] = [(int(ep_frames[o]), int(ep_globals[o])) for o in order]
ep_frame_to_local: dict[int, dict[int, int]] = {}
for ep_idx, sorted_list in ep_sorted.items():
ep_frame_to_local[ep_idx] = {frame: local for local, (frame, _) in enumerate(sorted_list)}
actions = hf["action"]
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
local_t = ep_frame_to_local[ep_idx][frame_idx]
source_local = max(0, local_t - offset)
_, source_global = ep_sorted[ep_idx][source_local]
return actions[source_global]
return _get_state
def main():
logger.info(f"Loading dataset {HF_DATASET_ID}")
dataset = LeRobotDataset(HF_DATASET_ID)
if "observation.state" in dataset.features:
logger.info("observation.state already exists — nothing to do")
return
action_meta = dataset.features["action"]
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
if STATE_SOURCE_FEATURES is not None:
source_list = (
[STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES)
)
for feat in source_list:
if feat not in dataset.features:
raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}")
# Compute combined shape and names
total_dim = 0
all_names = []
for feat in source_list:
meta = dataset.features[feat]
total_dim += meta["shape"][0]
names = meta.get("names")
if names:
all_names.extend(names)
logger.info(
f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})"
)
state_fn = _build_state_from_features(dataset, source_list)
state_feature_info = {
"dtype": "float32",
"shape": [total_dim],
"names": all_names or None,
}
else:
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
state_fn = _build_state_from_action_offset(dataset, offset=STATE_ACTION_OFFSET)
state_feature_info = {
"dtype": "float32",
"shape": list(action_meta["shape"]),
"names": action_meta.get("names"),
}
augmented = add_features(
dataset,
features={"observation.state": (state_fn, state_feature_info)},
repo_id=OUTPUT_REPO_ID,
)
logger.info("observation.state added")
if PUSH_TO_HUB:
logger.info(f"Pushing to Hub: {augmented.repo_id}")
augmented.push_to_hub()
logger.info(
f"Done. Dataset at: {augmented.root}\n"
"Now recompute relative action stats:\n"
" lerobot-edit-dataset \\\n"
f" --repo_id {augmented.repo_id} \\\n"
" --operation.type recompute_stats \\\n"
" --operation.relative_action true \\\n"
" --operation.chunk_size 50 \\\n"
' --operation.relative_exclude_joints "[]" \\\n'
" --push_to_hub true"
)
if __name__ == "__main__":
main()
+23 -7
View File
@@ -17,17 +17,20 @@
"""
Inference script for a pi0 model trained with **relative EE actions**.
This uses the built-in ``RelativeActionsProcessorStep`` and
``AbsoluteActionsProcessorStep`` that are already wired into pi0's
processor pipeline when ``use_relative_actions=True``.
This uses the built-in ``DeriveStateFromActionStep`` (no-op at inference),
``RelativeActionsProcessorStep``, ``AbsoluteActionsProcessorStep``, and
``RelativeStateProcessorStep`` that are already wired into pi0's processor
pipeline.
The inference loop:
1. Reads joint positions from the robot.
2. Converts to EE pose via forward kinematics (FK).
This produces ``observation.state`` with the current EE pose.
3. The pi0 preprocessor:
a) ``RelativeActionsProcessorStep`` caches the raw state.
b) ``NormalizerProcessorStep`` normalizes state and actions.
a) ``DeriveStateFromActionStep`` — no-op (state comes from robot).
b) ``RelativeActionsProcessorStep`` caches the raw state.
c) ``RelativeStateProcessorStep`` buffers prev state, stacks, subtracts.
d) ``NormalizerProcessorStep`` normalizes state and actions.
4. pi0 predicts relative action chunk.
5. The pi0 postprocessor:
a) ``UnnormalizerProcessorStep`` unnormalizes.
@@ -51,6 +54,7 @@ from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.processor import (
RelativeStateProcessorStep,
RobotProcessorPipeline,
make_default_teleop_action_processor,
)
@@ -79,6 +83,11 @@ TASK_DESCRIPTION = "manipulation task"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
# Latency compensation: skip this many steps from the start of each predicted
# action chunk. Formula: ceil(total_latency_ms / (1000 / FPS)).
# E.g. at 10Hz with ~200ms total system latency: ceil(200 / 100) = 2.
LATENCY_SKIP_STEPS = 0
# EE feature keys produced by ForwardKinematicsJointsToEE
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
@@ -94,6 +103,7 @@ def main():
robot = SO100Follower(robot_config)
policy = PI0Policy.from_pretrained(HF_MODEL_ID)
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
@@ -151,9 +161,8 @@ def main():
# Build pre/post processors from the trained model.
# The pi0 processor pipeline already includes:
# pre: ... → RelativeActionsProcessorStep → NormalizerProcessorStep
# pre: ... → RelativeStateProcessorStep → RelativeActionsProcessorStep → NormalizerProcessorStep
# post: UnnormalizerProcessorStep → AbsoluteActionsProcessorStep → ...
# These handle the relative ↔ absolute conversion automatically.
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
@@ -161,6 +170,9 @@ def main():
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
# Find the relative state step (if present) so we can reset its buffer between episodes.
_relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
robot.connect()
listener, events = init_keyboard_listener()
@@ -174,6 +186,10 @@ def main():
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Reset relative state buffer so velocity is zero at episode start
for step in _relative_state_steps:
step.reset()
record_loop(
robot=robot,
events=events,