mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
fixes, do stats in seperate script (existing)
This commit is contained in:
@@ -15,35 +15,31 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Add ``observation.state`` to an existing UMI LeRobot dataset and recompute
|
||||
stats for pi0 training with relative EE actions.
|
||||
Add ``observation.state`` to an existing LeRobot dataset.
|
||||
|
||||
UMI datasets already contain ``action`` (absolute EE pose from SLAM) and
|
||||
images. This script derives ``observation.state`` from the action column
|
||||
and recomputes statistics with relative action stats.
|
||||
pi0 with ``use_relative_actions=True`` requires ``observation.state`` to
|
||||
compute relative actions (action − state) on the fly. This script adds
|
||||
that feature when it doesn't already exist.
|
||||
|
||||
State-Action Offset:
|
||||
UMI SLAM produces a single trajectory of EE poses stored as ``action``.
|
||||
We derive ``observation.state`` from the same trajectory with a
|
||||
configurable offset:
|
||||
Two modes for deriving ``observation.state``:
|
||||
|
||||
state[t] = action[t - STATE_ACTION_OFFSET]
|
||||
1. **From an existing feature** (``STATE_SOURCE_FEATURE``):
|
||||
Copies an existing column (e.g. ``observation.joints`` or
|
||||
``observation.pose``) to ``observation.state``.
|
||||
|
||||
With offset=0, state equals action at the same timestep. With offset=1,
|
||||
state is the previous timestep's action — representing where the gripper
|
||||
*arrived* (the result of the previous command), which is what the robot
|
||||
knows at decision time. Offset=1 is the typical UMI convention.
|
||||
2. **From action with offset** (``STATE_SOURCE_FEATURE = None``):
|
||||
Derives state from the action column with a per-episode offset:
|
||||
state[t] = action[t - STATE_ACTION_OFFSET]
|
||||
|
||||
For the first frame(s) of each episode where t < offset, we use the
|
||||
earliest available action (action[0]) as state.
|
||||
After running this script, recompute relative action stats via the CLI:
|
||||
|
||||
After adding state, train with standard lerobot-train:
|
||||
lerobot-train \\
|
||||
--dataset.repo_id=<your_dataset> \\
|
||||
--policy.type=pi0 \\
|
||||
--policy.use_relative_actions=true \\
|
||||
--policy.relative_exclude_joints='["gripper"]' \\
|
||||
--policy.pretrained_path=lerobot/pi0_base
|
||||
lerobot-edit-dataset \\
|
||||
--repo_id <your_dataset> \\
|
||||
--operation.type recompute_stats \\
|
||||
--operation.relative_action true \\
|
||||
--operation.chunk_size 50 \\
|
||||
--operation.relative_exclude_joints "['gripper']" \\
|
||||
--push_to_hub true
|
||||
|
||||
Usage:
|
||||
python convert_umi_dataset.py
|
||||
@@ -52,63 +48,79 @@ Usage:
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.dataset_tools import add_features, recompute_stats
|
||||
from lerobot.datasets.dataset_tools import add_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Configuration ─────────────────────────────────────────────────────────
|
||||
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
HF_DATASET_ID = ""
|
||||
|
||||
# Offset between state and action indices within each episode.
|
||||
# Source for observation.state. Options:
|
||||
# - A feature name (e.g. "observation.joints", "observation.pose") to copy
|
||||
# an existing column. Must have the same shape as "action".
|
||||
# - None to derive state from action with STATE_ACTION_OFFSET.
|
||||
STATE_SOURCE_FEATURE: str | None = "observation.joints"
|
||||
|
||||
# Only used when STATE_SOURCE_FEATURE is None.
|
||||
# 0 → state[t] = action[t] (same instant)
|
||||
# 1 → state[t] = action[t-1] (state lags by 1 step — typical for UMI)
|
||||
# 1 → state[t] = action[t-1] (state lags by 1 step)
|
||||
STATE_ACTION_OFFSET = 1
|
||||
|
||||
# Joint names to keep absolute (not converted to relative).
|
||||
RELATIVE_EXCLUDE_JOINTS: list[str] = ["gripper"]
|
||||
|
||||
# pi0 chunk size (for relative stats computation).
|
||||
CHUNK_SIZE = 50
|
||||
# Push the augmented dataset to the Hugging Face Hub.
|
||||
PUSH_TO_HUB = True
|
||||
|
||||
|
||||
# ── Build state from action with offset ──────────────────────────────────
|
||||
def _build_state_from_feature(dataset: LeRobotDataset, source_feature: str) -> Callable:
|
||||
"""Return a callable that copies values from an existing feature."""
|
||||
hf = dataset.hf_dataset
|
||||
source_values = hf[source_feature]
|
||||
|
||||
episode_indices = np.array(hf["episode_index"])
|
||||
frame_indices = np.array(hf["frame_index"])
|
||||
key_to_global = {(int(episode_indices[i]), int(frame_indices[i])): i for i in range(len(episode_indices))}
|
||||
|
||||
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
|
||||
return source_values[key_to_global[(ep_idx, frame_idx)]]
|
||||
|
||||
return _get_state
|
||||
|
||||
|
||||
def build_state_array(dataset: LeRobotDataset, offset: int) -> np.ndarray:
|
||||
"""Derive observation.state from the action column with a per-episode offset.
|
||||
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
|
||||
"""Return a callable that derives state from action with a per-episode offset.
|
||||
|
||||
For each frame t in an episode:
|
||||
state[t] = action[max(0, t - offset)] (clamped to episode start)
|
||||
state[t] = action[max(0, t - offset)] (clamped to episode start)
|
||||
"""
|
||||
hf = dataset.hf_dataset
|
||||
actions = np.array(hf["action"], dtype=np.float32)
|
||||
episode_indices = np.array(hf["episode_index"])
|
||||
frame_indices = np.array(hf["frame_index"])
|
||||
|
||||
states = np.empty_like(actions)
|
||||
|
||||
ep_sorted: dict[int, list[tuple[int, int]]] = {}
|
||||
for ep_idx in np.unique(episode_indices):
|
||||
ep_mask = episode_indices == ep_idx
|
||||
ep_global_indices = np.where(ep_mask)[0]
|
||||
ep_actions = actions[ep_global_indices]
|
||||
ep_frames = frame_indices[ep_global_indices]
|
||||
ep_globals = np.where(ep_mask)[0]
|
||||
ep_frames = frame_indices[ep_globals]
|
||||
order = np.argsort(ep_frames)
|
||||
ep_sorted[int(ep_idx)] = [(int(ep_frames[o]), int(ep_globals[o])) for o in order]
|
||||
|
||||
sort_order = np.argsort(ep_frames)
|
||||
ep_global_indices = ep_global_indices[sort_order]
|
||||
ep_actions = ep_actions[sort_order]
|
||||
ep_frame_to_local: dict[int, dict[int, int]] = {}
|
||||
for ep_idx, sorted_list in ep_sorted.items():
|
||||
ep_frame_to_local[ep_idx] = {frame: local for local, (frame, _) in enumerate(sorted_list)}
|
||||
|
||||
n = len(ep_actions)
|
||||
for local_t in range(n):
|
||||
source_t = max(0, local_t - offset)
|
||||
states[ep_global_indices[local_t]] = ep_actions[source_t]
|
||||
actions = hf["action"]
|
||||
|
||||
return states
|
||||
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
|
||||
local_t = ep_frame_to_local[ep_idx][frame_idx]
|
||||
source_local = max(0, local_t - offset)
|
||||
_, source_global = ep_sorted[ep_idx][source_local]
|
||||
return actions[source_global]
|
||||
|
||||
return _get_state
|
||||
|
||||
|
||||
def main():
|
||||
@@ -116,44 +128,54 @@ def main():
|
||||
dataset = LeRobotDataset(HF_DATASET_ID)
|
||||
|
||||
if "observation.state" in dataset.features:
|
||||
logger.warning("observation.state already exists — skipping add_features")
|
||||
augmented = dataset
|
||||
else:
|
||||
logger.info(f"Building observation.state from action with offset={STATE_ACTION_OFFSET}")
|
||||
state_array = build_state_array(dataset, offset=STATE_ACTION_OFFSET)
|
||||
logger.info("observation.state already exists — nothing to do")
|
||||
return
|
||||
|
||||
action_meta = dataset.features["action"]
|
||||
action_meta = dataset.features["action"]
|
||||
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
|
||||
|
||||
if STATE_SOURCE_FEATURE is not None:
|
||||
if STATE_SOURCE_FEATURE not in dataset.features:
|
||||
raise ValueError(
|
||||
f"Source feature '{STATE_SOURCE_FEATURE}' not found. "
|
||||
f"Available: {list(dataset.features.keys())}"
|
||||
)
|
||||
source_meta = dataset.features[STATE_SOURCE_FEATURE]
|
||||
logger.info(f"Copying {STATE_SOURCE_FEATURE} → observation.state")
|
||||
state_fn = _build_state_from_feature(dataset, STATE_SOURCE_FEATURE)
|
||||
state_feature_info = {
|
||||
"dtype": "float32",
|
||||
"shape": list(source_meta["shape"]),
|
||||
"names": source_meta.get("names"),
|
||||
}
|
||||
else:
|
||||
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
|
||||
state_fn = _build_state_from_action_offset(dataset, offset=STATE_ACTION_OFFSET)
|
||||
state_feature_info = {
|
||||
"dtype": "float32",
|
||||
"shape": list(action_meta["shape"]),
|
||||
"names": action_meta.get("names"),
|
||||
}
|
||||
|
||||
augmented = add_features(
|
||||
dataset,
|
||||
features={
|
||||
"observation.state": (state_array, state_feature_info),
|
||||
},
|
||||
)
|
||||
logger.info("observation.state added")
|
||||
|
||||
logger.info("Recomputing stats with relative action statistics...")
|
||||
recompute_stats(
|
||||
augmented,
|
||||
relative_action=True,
|
||||
relative_exclude_joints=RELATIVE_EXCLUDE_JOINTS,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
augmented = add_features(
|
||||
dataset,
|
||||
features={"observation.state": (state_fn, state_feature_info)},
|
||||
)
|
||||
logger.info("observation.state added")
|
||||
|
||||
if PUSH_TO_HUB:
|
||||
logger.info(f"Pushing to Hub: {augmented.repo_id}")
|
||||
augmented.push_to_hub()
|
||||
|
||||
logger.info(f"Dataset ready at {augmented.root}")
|
||||
logger.info(
|
||||
"Train with:\n"
|
||||
" lerobot-train \\\n"
|
||||
f" --dataset.repo_id={augmented.repo_id} \\\n"
|
||||
" --policy.type=pi0 \\\n"
|
||||
" --policy.use_relative_actions=true \\\n"
|
||||
f" --policy.relative_exclude_joints='{RELATIVE_EXCLUDE_JOINTS}' \\\n"
|
||||
" --policy.pretrained_path=lerobot/pi0_base"
|
||||
f"Done. Now recompute relative action stats:\n"
|
||||
" lerobot-edit-dataset \\\n"
|
||||
f" --repo_id {augmented.repo_id} \\\n"
|
||||
" --operation.type recompute_stats \\\n"
|
||||
" --operation.relative_action true \\\n"
|
||||
" --operation.chunk_size 50 \\\n"
|
||||
" --operation.relative_exclude_joints \"['gripper']\" \\\n"
|
||||
" --push_to_hub true"
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user