mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
refactor to use relative state
This commit is contained in:
@@ -1,220 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Add ``observation.state`` to an existing LeRobot dataset.
|
||||
|
||||
pi0 uses ``observation.state`` as its proprioceptive input AND for
|
||||
relative action conversion (action − state). This script creates
|
||||
``observation.state`` by concatenating one or more existing features.
|
||||
|
||||
Ordering matters: the features whose dimensions correspond to ``action``
|
||||
must come FIRST, because ``RelativeActionsProcessorStep`` subtracts
|
||||
``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE
|
||||
pose) are appended after and are seen by the model but not used for
|
||||
relative conversion.
|
||||
|
||||
Example: action = [proximal, distal], and we want the model to also see
|
||||
the EE pose:
|
||||
|
||||
STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"]
|
||||
→ observation.state = [proximal, distal, x, y, z, ax, ay, az]
|
||||
|
||||
The relative conversion uses state[:2] = [proximal, distal] to subtract
|
||||
from action[:2], and the model sees all 8 dimensions.
|
||||
|
||||
After running this script, recompute relative action stats:
|
||||
|
||||
lerobot-edit-dataset \\
|
||||
--repo_id <your_dataset> \\
|
||||
--operation.type recompute_stats \\
|
||||
--operation.relative_action true \\
|
||||
--operation.chunk_size 50 \\
|
||||
--operation.relative_exclude_joints "[]" \\
|
||||
--push_to_hub true
|
||||
|
||||
Usage:
|
||||
python convert_umi_dataset.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.dataset_tools import add_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Repo ID of the input dataset; passed to LeRobotDataset() in main().
# Empty by default -- presumably must be filled in before running (TODO confirm).
HF_DATASET_ID = ""
|
||||
|
||||
# Output repo ID. Set to None for default "<input>_modified".
|
||||
OUTPUT_REPO_ID: str | None = None
|
||||
|
||||
# Features to concatenate into observation.state. Order matters:
|
||||
# action-matching features FIRST, then extra proprioception.
|
||||
# Set to a single string to copy one feature directly, or to None to derive
# the state from ``action`` instead (see STATE_ACTION_OFFSET below).
|
||||
STATE_SOURCE_FEATURES: list[str] | str | None = ["observation.joints", "observation.pose"]
|
||||
|
||||
# Only used when STATE_SOURCE_FEATURES is None:
|
||||
# derive state from action with a per-episode offset.
|
||||
STATE_ACTION_OFFSET = 1
|
||||
|
||||
# Push the augmented dataset to the Hugging Face Hub.
|
||||
PUSH_TO_HUB = True
|
||||
|
||||
|
||||
def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]:
    """Map ``(episode_index, frame_index)`` to the row position in ``dataset.hf_dataset``.

    Args:
        dataset: Object exposing ``hf_dataset`` with ``"episode_index"`` and
            ``"frame_index"`` columns (assumed to have the same length).

    Returns:
        A dict keyed by ``(episode, frame)`` as plain Python ints whose value is
        the 0-based row position. If a key occurs more than once, the last row wins.
    """
    hf = dataset.hf_dataset
    episodes = np.array(hf["episode_index"])
    frames = np.array(hf["frame_index"])
    # Walk both columns in lockstep instead of indexing by position.
    return {
        (int(episode), int(frame)): row
        for row, (episode, frame) in enumerate(zip(episodes, frames))
    }
|
||||
|
||||
|
||||
def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable:
    """Return a per-frame callback that concatenates ``source_features`` into one flat list.

    The columns named in ``source_features`` are read from ``dataset.hf_dataset``
    once, up front. The returned callback takes ``(row_dict, ep_idx, frame_idx)``,
    looks up the row for ``(ep_idx, frame_idx)``, and joins that row's values in
    the order the features were given. ``row_dict`` is accepted but not used.
    """
    table = dataset.hf_dataset
    row_of = _build_global_index(dataset)
    feature_columns = [table[name] for name in source_features]

    def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
        row = row_of[(ep_idx, frame_idx)]
        state = []
        for column in feature_columns:
            value = column[row]
            if hasattr(value, "tolist"):
                # Array-like objects: convert via their own tolist(); the result
                # may be a list or a bare scalar.
                value = value.tolist()
            elif not isinstance(value, list):
                # Anything else that is not already a list is coerced to float.
                value = float(value)

            if isinstance(value, list):
                state.extend(value)
            else:
                state.append(value)
        return state

    return _get_state
|
||||
|
||||
|
||||
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
    """Return a per-frame callback that yields an earlier ``action`` from the same episode.

    Frames of each episode are ordered by ``frame_index``. For the frame at
    position ``t`` in that ordering, the callback returns the ``action`` entry of
    the frame at position ``max(0, t - offset)``, so the lookup never reaches
    before the first frame of the episode. The callback signature is
    ``(row_dict, ep_idx, frame_idx)``; ``row_dict`` is accepted but not used.
    """
    table = dataset.hf_dataset
    all_episodes = np.array(table["episode_index"])
    all_frames = np.array(table["frame_index"])

    # episode -> [(frame_index, global_row), ...] in ascending frame order
    timeline: dict[int, list[tuple[int, int]]] = {}
    # episode -> {frame_index: position inside that episode's timeline}
    position: dict[int, dict[int, int]] = {}

    for episode in np.unique(all_episodes):
        rows = np.nonzero(all_episodes == episode)[0]
        frames = all_frames[rows]
        ordered = [(int(frames[k]), int(rows[k])) for k in np.argsort(frames)]
        key = int(episode)
        timeline[key] = ordered
        position[key] = {frame: slot for slot, (frame, _) in enumerate(ordered)}

    actions = table["action"]

    def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
        slot = position[ep_idx][frame_idx] - offset
        if slot < 0:
            # Clamp to the first frame of the episode.
            slot = 0
        return actions[timeline[ep_idx][slot][1]]

    return _get_state
|
||||
|
||||
|
||||
def main():
    """Add ``observation.state`` to the configured dataset and optionally push it.

    Steps:
      1. Load ``HF_DATASET_ID``; return early if ``observation.state`` already exists.
      2. If ``STATE_SOURCE_FEATURES`` is not ``None``, build the state by
         concatenating those features (in order); otherwise derive it from
         ``action`` using ``STATE_ACTION_OFFSET``.
      3. Write the new feature through ``add_features`` and, when
         ``PUSH_TO_HUB`` is true, push the result.

    Raises:
        ValueError: If ``STATE_SOURCE_FEATURES`` is an empty list or names a
            feature that is not present in the dataset.
    """
    logger.info(f"Loading dataset {HF_DATASET_ID}")
    dataset = LeRobotDataset(HF_DATASET_ID)

    if "observation.state" in dataset.features:
        logger.info("observation.state already exists — nothing to do")
        return

    action_meta = dataset.features["action"]
    logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")

    if STATE_SOURCE_FEATURES is not None:
        source_list = (
            [STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES)
        )
        if not source_list:
            # An empty list would silently create a zero-width observation.state.
            raise ValueError(
                "STATE_SOURCE_FEATURES is empty; list at least one feature or set it to None "
                "to derive the state from action."
            )
        for feat in source_list:
            if feat not in dataset.features:
                raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}")

        # Compute combined shape and names
        total_dim = 0
        all_names = []
        for feat in source_list:
            meta = dataset.features[feat]
            total_dim += meta["shape"][0]
            names = meta.get("names")
            if names:
                all_names.extend(names)

        # Names are only meaningful when there is exactly one per state dimension.
        # If some source features carry no names (or carry a different number of
        # names than dimensions), a partial list would mislabel the dimensions.
        if all_names and len(all_names) != total_dim:
            logger.warning(
                f"Collected {len(all_names)} names for {total_dim} state dimensions; "
                "dropping names for observation.state"
            )
            all_names = []

        logger.info(
            f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})"
        )
        state_fn = _build_state_from_features(dataset, source_list)
        state_feature_info = {
            "dtype": "float32",
            "shape": [total_dim],
            "names": all_names or None,
        }
    else:
        logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
        state_fn = _build_state_from_action_offset(dataset, offset=STATE_ACTION_OFFSET)
        # The derived state mirrors the action feature's shape and names.
        state_feature_info = {
            "dtype": "float32",
            "shape": list(action_meta["shape"]),
            "names": action_meta.get("names"),
        }

    augmented = add_features(
        dataset,
        features={"observation.state": (state_fn, state_feature_info)},
        repo_id=OUTPUT_REPO_ID,
    )
    logger.info("observation.state added")

    if PUSH_TO_HUB:
        logger.info(f"Pushing to Hub: {augmented.repo_id}")
        augmented.push_to_hub()

    logger.info(
        f"Done. Dataset at: {augmented.root}\n"
        "Now recompute relative action stats:\n"
        " lerobot-edit-dataset \\\n"
        f" --repo_id {augmented.repo_id} \\\n"
        " --operation.type recompute_stats \\\n"
        " --operation.relative_action true \\\n"
        " --operation.chunk_size 50 \\\n"
        ' --operation.relative_exclude_joints "[]" \\\n'
        " --push_to_hub true"
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -17,17 +17,20 @@
|
||||
"""
|
||||
Inference script for a pi0 model trained with **relative EE actions**.
|
||||
|
||||
This uses the built-in ``RelativeActionsProcessorStep`` and
|
||||
``AbsoluteActionsProcessorStep`` that are already wired into pi0's
|
||||
processor pipeline when ``use_relative_actions=True``.
|
||||
This uses the built-in ``DeriveStateFromActionStep`` (no-op at inference),
|
||||
``RelativeActionsProcessorStep``, ``AbsoluteActionsProcessorStep``, and
|
||||
``RelativeStateProcessorStep`` that are already wired into pi0's processor
|
||||
pipeline.
|
||||
|
||||
The inference loop:
|
||||
1. Reads joint positions from the robot.
|
||||
2. Converts to EE pose via forward kinematics (FK).
|
||||
This produces ``observation.state`` with the current EE pose.
|
||||
3. The pi0 preprocessor:
|
||||
a) ``RelativeActionsProcessorStep`` caches the raw state.
|
||||
b) ``NormalizerProcessorStep`` normalizes state and actions.
|
||||
a) ``DeriveStateFromActionStep`` — no-op (state comes from robot).
|
||||
b) ``RelativeActionsProcessorStep`` caches the raw state.
|
||||
c) ``RelativeStateProcessorStep`` buffers prev state, stacks, subtracts.
|
||||
d) ``NormalizerProcessorStep`` normalizes state and actions.
|
||||
4. pi0 predicts relative action chunk.
|
||||
5. The pi0 postprocessor:
|
||||
a) ``UnnormalizerProcessorStep`` unnormalizes.
|
||||
@@ -51,6 +54,7 @@ from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.processor import (
|
||||
RelativeStateProcessorStep,
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
@@ -79,6 +83,11 @@ TASK_DESCRIPTION = "manipulation task"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Latency compensation: skip this many steps from the start of each predicted
|
||||
# action chunk. Formula: ceil(total_latency_ms / (1000 / FPS)).
|
||||
# E.g. at 10Hz with ~200ms total system latency: ceil(200 / 100) = 2.
|
||||
LATENCY_SKIP_STEPS = 0
|
||||
|
||||
# EE feature keys produced by ForwardKinematicsJointsToEE
|
||||
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
|
||||
@@ -94,6 +103,7 @@ def main():
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
policy = PI0Policy.from_pretrained(HF_MODEL_ID)
|
||||
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
|
||||
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
@@ -151,9 +161,8 @@ def main():
|
||||
|
||||
# Build pre/post processors from the trained model.
|
||||
# The pi0 processor pipeline already includes:
|
||||
# pre: ... → RelativeActionsProcessorStep → NormalizerProcessorStep
|
||||
# pre: ... → RelativeStateProcessorStep → RelativeActionsProcessorStep → NormalizerProcessorStep
|
||||
# post: UnnormalizerProcessorStep → AbsoluteActionsProcessorStep → ...
|
||||
# These handle the relative ↔ absolute conversion automatically.
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
@@ -161,6 +170,9 @@ def main():
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Find the relative state step (if present) so we can reset its buffer between episodes.
|
||||
_relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
|
||||
|
||||
robot.connect()
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
@@ -174,6 +186,10 @@ def main():
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Reset relative state buffer so velocity is zero at episode start
|
||||
for step in _relative_state_steps:
|
||||
step.reset()
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
|
||||
Reference in New Issue
Block a user