diff --git a/docs/source/umi_pi0_relative_ee.mdx b/docs/source/umi_pi0_relative_ee.mdx index 26a24d11c..26a7807d3 100644 --- a/docs/source/umi_pi0_relative_ee.mdx +++ b/docs/source/umi_pi0_relative_ee.mdx @@ -40,38 +40,57 @@ state[t] = action[t - offset] An offset of 1 is the typical UMI convention: at decision time the "current state" is where the gripper _already is_ (the result of the previous command), and the action is where it should go next. At episode boundaries where `t < offset`, we clamp to `action[0]`. -## Step 1: Add State and Recompute Stats +## Step 1: Add `observation.state` -The conversion script in `examples/umi_pi0_relative_ee/convert_umi_dataset.py` handles both steps. Edit the constants at the top of the file: +pi0 with `use_relative_actions=True` needs `observation.state` in the dataset to compute `action - state` on the fly. The script in `examples/umi_pi0_relative_ee/convert_umi_dataset.py` adds it. Edit the constants at the top: ```python HF_DATASET_ID = "/" + +# Option A: Copy an existing feature as observation.state +STATE_SOURCE_FEATURE = "observation.joints" # or "observation.pose", etc. + +# Option B: Derive from action with offset (set STATE_SOURCE_FEATURE = None) +STATE_SOURCE_FEATURE = None STATE_ACTION_OFFSET = 1 -RELATIVE_EXCLUDE_JOINTS = ["gripper"] -CHUNK_SIZE = 50 ``` +**Choosing the state source:** + +- If your dataset already has a feature in the same space as `action` (e.g. `observation.joints` for joint-space actions, or `observation.pose` for EE-space actions), set `STATE_SOURCE_FEATURE` to copy it. +- If your dataset only has a single trajectory (like raw UMI EE data where action = the EE poses), set `STATE_SOURCE_FEATURE = None` and use `STATE_ACTION_OFFSET` to derive state from the action column with a time offset. + +`observation.state` **must have the same shape as `action`** — the relative conversion computes `action - state` element-wise. 
+ Then run: ```bash python examples/umi_pi0_relative_ee/convert_umi_dataset.py ``` -This: - -- Loads your existing UMI LeRobot dataset. -- Adds `observation.state` derived from the `action` column with the configured offset. -- Calls `recompute_stats(relative_action=True)` to compute chunk-level relative action statistics. - -The `RELATIVE_EXCLUDE_JOINTS` parameter specifies joints that stay absolute. Gripper commands are typically binary or continuous open/close and don't benefit from relative encoding. - -If your dataset already has `observation.state`, the script skips the feature-adding step and only recomputes relative action statistics. +If your dataset already has `observation.state`, the script exits early — nothing to do. -## Step 2: Train +## Step 2: Recompute Relative Action Stats + +Use the built-in CLI to recompute dataset statistics in relative space: + +```bash +lerobot-edit-dataset \ + --repo_id \ + --operation.type recompute_stats \ + --operation.relative_action true \ + --operation.chunk_size 50 \ + --operation.relative_exclude_joints "['gripper']" \ + --push_to_hub true +``` + +The `relative_exclude_joints` parameter specifies joints that stay absolute. Gripper commands are typically binary or continuous open/close and don't benefit from relative encoding. Leave it as `"[]"` to convert all dimensions to relative. + +## Step 3: Train No custom training script is needed — standard `lerobot-train` handles everything: @@ -92,7 +111,7 @@ Under the hood, the training pipeline: See the [pi0 documentation](pi0) for all available training options. 
-## Step 3: Evaluate +## Step 4: Evaluate The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real robot (SO-100 with EE space): diff --git a/examples/umi_pi0_relative_ee/convert_umi_dataset.py b/examples/umi_pi0_relative_ee/convert_umi_dataset.py index 17bb5ffd9..74e70ea4d 100644 --- a/examples/umi_pi0_relative_ee/convert_umi_dataset.py +++ b/examples/umi_pi0_relative_ee/convert_umi_dataset.py @@ -15,35 +15,31 @@ # limitations under the License. """ -Add ``observation.state`` to an existing UMI LeRobot dataset and recompute -stats for pi0 training with relative EE actions. +Add ``observation.state`` to an existing LeRobot dataset. -UMI datasets already contain ``action`` (absolute EE pose from SLAM) and -images. This script derives ``observation.state`` from the action column -and recomputes statistics with relative action stats. +pi0 with ``use_relative_actions=True`` requires ``observation.state`` to +compute relative actions (action − state) on the fly. This script adds +that feature when it doesn't already exist. -State-Action Offset: -UMI SLAM produces a single trajectory of EE poses stored as ``action``. -We derive ``observation.state`` from the same trajectory with a -configurable offset: +Two modes for deriving ``observation.state``: - state[t] = action[t - STATE_ACTION_OFFSET] + 1. **From an existing feature** (``STATE_SOURCE_FEATURE``): + Copies an existing column (e.g. ``observation.joints`` or + ``observation.pose``) to ``observation.state``. -With offset=0, state equals action at the same timestep. With offset=1, -state is the previous timestep's action — representing where the gripper -*arrived* (the result of the previous command), which is what the robot -knows at decision time. Offset=1 is the typical UMI convention. + 2. 
**From action with offset** (``STATE_SOURCE_FEATURE = None``): + Derives state from the action column with a per-episode offset: + state[t] = action[t - STATE_ACTION_OFFSET] -For the first frame(s) of each episode where t < offset, we use the -earliest available action (action[0]) as state. +After running this script, recompute relative action stats via the CLI: -After adding state, train with standard lerobot-train: - lerobot-train \\ - --dataset.repo_id= \\ - --policy.type=pi0 \\ - --policy.use_relative_actions=true \\ - --policy.relative_exclude_joints='["gripper"]' \\ - --policy.pretrained_path=lerobot/pi0_base + lerobot-edit-dataset \\ + --repo_id \\ + --operation.type recompute_stats \\ + --operation.relative_action true \\ + --operation.chunk_size 50 \\ + --operation.relative_exclude_joints "['gripper']" \\ + --push_to_hub true Usage: python convert_umi_dataset.py @@ -52,63 +48,79 @@ Usage: from __future__ import annotations import logging +from collections.abc import Callable import numpy as np -from lerobot.datasets.dataset_tools import add_features, recompute_stats +from lerobot.datasets.dataset_tools import add_features from lerobot.datasets.lerobot_dataset import LeRobotDataset logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# ── Configuration ───────────────────────────────────────────────────────── -HF_DATASET_ID = "/" +HF_DATASET_ID = "" -# Offset between state and action indices within each episode. +# Source for observation.state. Options: +# - A feature name (e.g. "observation.joints", "observation.pose") to copy +# an existing column. Must have the same shape as "action". +# - None to derive state from action with STATE_ACTION_OFFSET. +STATE_SOURCE_FEATURE: str | None = "observation.joints" + +# Only used when STATE_SOURCE_FEATURE is None. 
# 0 → state[t] = action[t] (same instant) -# 1 → state[t] = action[t-1] (state lags by 1 step — typical for UMI) +# 1 → state[t] = action[t-1] (state lags by 1 step) STATE_ACTION_OFFSET = 1 -# Joint names to keep absolute (not converted to relative). -RELATIVE_EXCLUDE_JOINTS: list[str] = ["gripper"] - -# pi0 chunk size (for relative stats computation). -CHUNK_SIZE = 50 +# Push the augmented dataset to the Hugging Face Hub. +PUSH_TO_HUB = True -# ── Build state from action with offset ────────────────────────────────── +def _build_state_from_feature(dataset: LeRobotDataset, source_feature: str) -> Callable: + """Return a callable that copies values from an existing feature.""" + hf = dataset.hf_dataset + source_values = hf[source_feature] + + episode_indices = np.array(hf["episode_index"]) + frame_indices = np.array(hf["frame_index"]) + key_to_global = {(int(episode_indices[i]), int(frame_indices[i])): i for i in range(len(episode_indices))} + + def _get_state(row_dict: dict, ep_idx: int, frame_idx: int): + return source_values[key_to_global[(ep_idx, frame_idx)]] + + return _get_state -def build_state_array(dataset: LeRobotDataset, offset: int) -> np.ndarray: - """Derive observation.state from the action column with a per-episode offset. +def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable: + """Return a callable that derives state from action with a per-episode offset. 
- For each frame t in an episode: - state[t] = action[max(0, t - offset)] (clamped to episode start) + state[t] = action[max(0, t - offset)] (clamped to episode start) """ hf = dataset.hf_dataset - actions = np.array(hf["action"], dtype=np.float32) episode_indices = np.array(hf["episode_index"]) frame_indices = np.array(hf["frame_index"]) - states = np.empty_like(actions) - + ep_sorted: dict[int, list[tuple[int, int]]] = {} for ep_idx in np.unique(episode_indices): ep_mask = episode_indices == ep_idx - ep_global_indices = np.where(ep_mask)[0] - ep_actions = actions[ep_global_indices] - ep_frames = frame_indices[ep_global_indices] + ep_globals = np.where(ep_mask)[0] + ep_frames = frame_indices[ep_globals] + order = np.argsort(ep_frames) + ep_sorted[int(ep_idx)] = [(int(ep_frames[o]), int(ep_globals[o])) for o in order] - sort_order = np.argsort(ep_frames) - ep_global_indices = ep_global_indices[sort_order] - ep_actions = ep_actions[sort_order] + ep_frame_to_local: dict[int, dict[int, int]] = {} + for ep_idx, sorted_list in ep_sorted.items(): + ep_frame_to_local[ep_idx] = {frame: local for local, (frame, _) in enumerate(sorted_list)} - n = len(ep_actions) - for local_t in range(n): - source_t = max(0, local_t - offset) - states[ep_global_indices[local_t]] = ep_actions[source_t] + actions = hf["action"] - return states + def _get_state(row_dict: dict, ep_idx: int, frame_idx: int): + local_t = ep_frame_to_local[ep_idx][frame_idx] + source_local = max(0, local_t - offset) + _, source_global = ep_sorted[ep_idx][source_local] + return actions[source_global] + + return _get_state def main(): @@ -116,44 +128,54 @@ def main(): dataset = LeRobotDataset(HF_DATASET_ID) if "observation.state" in dataset.features: - logger.warning("observation.state already exists — skipping add_features") - augmented = dataset - else: - logger.info(f"Building observation.state from action with offset={STATE_ACTION_OFFSET}") - state_array = build_state_array(dataset, offset=STATE_ACTION_OFFSET) + 
logger.info("observation.state already exists — nothing to do") + return - action_meta = dataset.features["action"] + action_meta = dataset.features["action"] + logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}") + + if STATE_SOURCE_FEATURE is not None: + if STATE_SOURCE_FEATURE not in dataset.features: + raise ValueError( + f"Source feature '{STATE_SOURCE_FEATURE}' not found. " + f"Available: {list(dataset.features.keys())}" + ) + source_meta = dataset.features[STATE_SOURCE_FEATURE] + logger.info(f"Copying {STATE_SOURCE_FEATURE} → observation.state") + state_fn = _build_state_from_feature(dataset, STATE_SOURCE_FEATURE) + state_feature_info = { + "dtype": "float32", + "shape": list(source_meta["shape"]), + "names": source_meta.get("names"), + } + else: + logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}") + state_fn = _build_state_from_action_offset(dataset, offset=STATE_ACTION_OFFSET) state_feature_info = { "dtype": "float32", "shape": list(action_meta["shape"]), "names": action_meta.get("names"), } - augmented = add_features( - dataset, - features={ - "observation.state": (state_array, state_feature_info), - }, - ) - logger.info("observation.state added") - - logger.info("Recomputing stats with relative action statistics...") - recompute_stats( - augmented, - relative_action=True, - relative_exclude_joints=RELATIVE_EXCLUDE_JOINTS, - chunk_size=CHUNK_SIZE, + augmented = add_features( + dataset, + features={"observation.state": (state_fn, state_feature_info)}, ) + logger.info("observation.state added") + + if PUSH_TO_HUB: + logger.info(f"Pushing to Hub: {augmented.repo_id}") + augmented.push_to_hub() - logger.info(f"Dataset ready at {augmented.root}") logger.info( - "Train with:\n" - " lerobot-train \\\n" - f" --dataset.repo_id={augmented.repo_id} \\\n" - " --policy.type=pi0 \\\n" - " --policy.use_relative_actions=true \\\n" - f" --policy.relative_exclude_joints='{RELATIVE_EXCLUDE_JOINTS}' 
\\\n" - " --policy.pretrained_path=lerobot/pi0_base" + f"Done. Now recompute relative action stats:\n" + " lerobot-edit-dataset \\\n" + f" --repo_id {augmented.repo_id} \\\n" + " --operation.type recompute_stats \\\n" + " --operation.relative_action true \\\n" + " --operation.chunk_size 50 \\\n" + " --operation.relative_exclude_joints \"['gripper']\" \\\n" + " --push_to_hub true" )
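
For reference, the per-episode offset rule behind `_build_state_from_action_offset`, `state[t] = action[max(0, t - offset)]`, can be sketched on plain NumPy arrays. This is a minimal illustration under stated assumptions, not the script's code: the helper name `derive_state_from_action` and the toy data are hypothetical, and it assumes each row carries an `episode_index` and a `frame_index` as in the diff above.

```python
import numpy as np


def derive_state_from_action(actions, episode_index, frame_index, offset=1):
    """Hypothetical helper: state[t] = action[max(0, t - offset)] per episode."""
    actions = np.asarray(actions, dtype=np.float32)
    episode_index = np.asarray(episode_index)
    frame_index = np.asarray(frame_index)
    states = np.empty_like(actions)

    for ep in np.unique(episode_index):
        # Global row indices of this episode, ordered by frame index.
        rows = np.where(episode_index == ep)[0]
        rows = rows[np.argsort(frame_index[rows])]
        # Source position for each local step, clamped to the episode start.
        src = np.maximum(np.arange(len(rows)) - offset, 0)
        states[rows] = actions[rows[src]]

    return states


# Toy data (made up): two episodes of three frames each, 2-D actions.
actions = np.arange(12, dtype=np.float32).reshape(6, 2)
episode_index = np.array([0, 0, 0, 1, 1, 1])
frame_index = np.array([0, 1, 2, 0, 1, 2])

states = derive_state_from_action(actions, episode_index, frame_index, offset=1)

# An element-wise difference is only defined because state and action
# have the same shape; this is a shape check, not the training pipeline.
relative = actions - states
```

With `offset=0` the helper returns the actions unchanged. With `offset=1` the first frame of each episode is clamped to that episode's own `action[0]`, so values never leak across episode boundaries.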