diff --git a/docs/source/action_representations.mdx b/docs/source/action_representations.mdx index 1604ed467..ba73196de 100644 --- a/docs/source/action_representations.mdx +++ b/docs/source/action_representations.mdx @@ -202,11 +202,22 @@ Here is how the different processors compose. Each arrow is a processor step, an └─────────────────────────────────────────┘ ┌─────────────────────────────────────────┐ - Representation │ Absolute ←────→ Relative │ + State Derivation │ Action column ────→ State + Action │ + │ DeriveStateFromActionStep (pre only) │ + │ (UMI-style: state from action chunk) │ + └─────────────────────────────────────────┘ + + ┌─────────────────────────────────────────┐ + Action Repr. │ Absolute ←────→ Relative │ │ RelativeActionsProcessorStep (pre) │ │ AbsoluteActionsProcessorStep (post) │ └─────────────────────────────────────────┘ + ┌─────────────────────────────────────────┐ + State Repr. │ Absolute ────→ Relative │ + │ RelativeStateProcessorStep (pre only) │ + └─────────────────────────────────────────┘ + ┌─────────────────────────────────────────┐ Normalization │ Raw ←────→ Normalized │ │ NormalizerProcessorStep (pre) │ @@ -216,6 +227,10 @@ Here is how the different processors compose. Each arrow is a processor step, an A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`. +With UMI-style relative proprioception (`use_relative_state=True`), the preprocessor also converts observation.state to offsets from the current timestep via `RelativeStateProcessorStep` before normalization. This is a pre-processing-only step (state is an input, not an output). + +With `derive_state_from_action=True`, the preprocessor first runs `DeriveStateFromActionStep` to extract a 2-step state from the extended action chunk. This enables full UMI-style training without a separate `observation.state` column. 
See the [UMI pi0 guide](umi_pi0_relative_ee) for details. + ## References - [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions. diff --git a/docs/source/umi_pi0_relative_ee.mdx b/docs/source/umi_pi0_relative_ee.mdx index 26a7807d3..090c68fd7 100644 --- a/docs/source/umi_pi0_relative_ee.mdx +++ b/docs/source/umi_pi0_relative_ee.mdx @@ -4,16 +4,13 @@ This guide explains how to prepare a UMI-collected dataset for training a pi0 po **What we will do:** -1. How to add `observation.state` to an existing UMI LeRobot dataset. -2. How to train pi0 with `use_relative_actions=True`. -3. How to evaluate the trained policy on a real robot. +1. Recompute dataset statistics for relative actions and state. +2. Train pi0 with `derive_state_from_action=true` (full UMI pipeline). +3. Evaluate the trained policy on a real robot. ## Background -[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. UMI datasets stored in LeRobot format already contain `action` (absolute EE pose) and wrist-camera images. To train pi0 with relative actions, we need two additions: - -1. **`observation.state`** — the current EE pose the policy conditions on. -2. **Relative action statistics** — so the normalizer sees `(action − state)` distributions. +[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. UMI datasets stored in LeRobot format already contain `action` (absolute EE pose) and wrist-camera images. To train pi0 with relative actions, we need **relative action statistics** — so the normalizer sees `(action − state)` distributions. ### Why relative actions? 
@@ -25,72 +22,39 @@ relative_action[i] = absolute_action[t + i] − state[t] This is the representation advocated by UMI (Chi et al., 2024). Compared to absolute actions it removes the need for a consistent global coordinate frame, and compared to delta actions (each step relative to the previous) it avoids error accumulation across the chunk. See the [Action Representations](action_representations) guide for a full comparison. -## State-Action Offset +### Full UMI mode: `derive_state_from_action` -UMI SLAM produces a single trajectory of EE poses stored as `action`. We derive `observation.state` from the same trajectory with a configurable offset: +When `derive_state_from_action=true`, pi0 automatically derives `observation.state` from the action column on the fly — no separate state column or dataset conversion step needed. Under the hood: -``` -state[t] = action[t - offset] -``` +- `action_delta_indices` extends to `[-1, 0, 1, ..., 49]` (one extra leading timestep). +- `DeriveStateFromActionStep` extracts `[action[t-1], action[t]]` as a 2-step state and strips the extra timestep from the action chunk. +- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`. +- `RelativeStateProcessorStep` converts the 2-step state to relative proprioception (velocity + zeros) and flattens. -| Offset | `state[t]` | Meaning | -| ------ | ------------- | ---------------------------------------------------------------- | -| 0 | `action[t]` | State and action are the same pose at time t | -| 1 | `action[t-1]` | State is the previous action — where the gripper already arrived | +This single flag implies `use_relative_state=true` and `state_obs_steps=2`. -An offset of 1 is the typical UMI convention: at decision time the "current state" is where the gripper _already is_ (the result of the previous command), and the action is where it should go next. At episode boundaries where `t < offset`, we clamp to `action[0]`. 
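For intuition, the conversion defined by `relative_action[i] = absolute_action[t + i] − state[t]` can be sketched in a few lines of NumPy. This is a standalone illustration, not the library's `RelativeActionsProcessorStep`; the function names and the `exclude` handling are hypothetical:

```python
import numpy as np

def to_relative(actions: np.ndarray, state: np.ndarray, exclude: tuple = ()) -> np.ndarray:
    """relative_action[i] = absolute_action[t + i] - state[t].

    actions: [chunk_size, D] absolute targets for the chunk starting at t.
    state:   [D] current pose at time t.
    exclude: dimension indices (e.g. the gripper) that stay absolute.
    """
    mask = np.ones(actions.shape[-1], dtype=actions.dtype)
    mask[list(exclude)] = 0.0  # excluded dims are not offset
    return actions - state * mask

def to_absolute(rel_actions: np.ndarray, state: np.ndarray, exclude: tuple = ()) -> np.ndarray:
    """Inverse conversion: add the cached state back to each chunk step."""
    mask = np.ones(rel_actions.shape[-1], dtype=rel_actions.dtype)
    mask[list(exclude)] = 0.0
    return rel_actions + state * mask
```

Because every step in the chunk subtracts the same `state[t]`, a constant pose error shifts all targets equally rather than accumulating step to step, which is the advantage over delta actions noted above.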
+During **inference**, state comes from the robot (via FK), so `DeriveStateFromActionStep` is a no-op. `RelativeStateProcessorStep` buffers the previous state and applies the same conversion automatically. -## Step 1: Add `observation.state` +## Step 1: Recompute Stats -pi0 with `use_relative_actions=True` needs `observation.state` in the dataset to compute `action - state` on the fly. The script in `examples/umi_pi0_relative_ee/convert_umi_dataset.py` adds it. Edit the constants at the top: - -```python -HF_DATASET_ID = "/" - -# Option A: Copy an existing feature as observation.state -STATE_SOURCE_FEATURE = "observation.joints" # or "observation.pose", etc. - -# Option B: Derive from action with offset (set STATE_SOURCE_FEATURE = None) -STATE_SOURCE_FEATURE = None -STATE_ACTION_OFFSET = 1 -``` - -**Choosing the state source:** - -- If your dataset already has a feature in the same space as `action` (e.g. `observation.joints` for joint-space actions, or `observation.pose` for EE-space actions), set `STATE_SOURCE_FEATURE` to copy it. -- If your dataset only has a single trajectory (like raw UMI EE data where action = the EE poses), set `STATE_SOURCE_FEATURE = None` and use `STATE_ACTION_OFFSET` to derive state from the action column with a time offset. - -`observation.state` **must have the same shape as `action`** — the relative conversion computes `action - state` element-wise. - -Then run: - -```bash -python examples/umi_pi0_relative_ee/convert_umi_dataset.py -``` - - - -If your dataset already has `observation.state`, the script exits early — nothing to do. 
- - - -## Step 2: Recompute Relative Action Stats - -Use the built-in CLI to recompute dataset statistics in relative space: +Use the built-in CLI to recompute dataset statistics for relative actions and derive-state-from-action: ```bash lerobot-edit-dataset \ --repo_id \ --operation.type recompute_stats \ --operation.relative_action true \ + --operation.derive_state_from_action true \ --operation.chunk_size 50 \ --operation.relative_exclude_joints "['gripper']" \ --push_to_hub true ``` +The `derive_state_from_action` flag tells `recompute_stats` to read from the action column (instead of `observation.state`) when computing relative state stats. It automatically enables `relative_state=true` and `state_obs_steps=2`. + The `relative_exclude_joints` parameter specifies joints that stay absolute. Gripper commands are typically binary or continuous open/close and don't benefit from relative encoding. Leave it as `"[]"` to convert all dimensions to relative. -## Step 3: Train +## Step 2: Train No custom training script is needed — standard `lerobot-train` handles everything: @@ -99,19 +63,26 @@ lerobot-train \ --dataset.repo_id=/ \ --policy.type=pi0 \ --policy.pretrained_path=lerobot/pi0_base \ + --policy.derive_state_from_action=true \ --policy.use_relative_actions=true \ --policy.relative_exclude_joints='["gripper"]' ``` +`derive_state_from_action=true` auto-enables `use_relative_state=true` and `state_obs_steps=2`. + Under the hood, the training pipeline: -- Loads relative action stats from the dataset's `meta/stats.json`. -- Configures `RelativeActionsProcessorStep` in the preprocessor (absolute → relative before normalization). -- The model trains on normalized relative action values. +- Loads relative action stats and relative state stats from the dataset's `meta/stats.json`. +- Extends `action_delta_indices` to `[-1, 0, 1, ..., 49]` to load one extra leading timestep. 
+- `DeriveStateFromActionStep` extracts the 2-step state from the action chunk and strips the extra timestep. +- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`. +- `RelativeStateProcessorStep` converts the 2-step state to relative offsets from the current timestep, then flattens. +- `NormalizerProcessorStep` normalizes everything. +- The model trains on normalized relative values. See the [pi0 documentation](pi0) for all available training options. -## Step 4: Evaluate +## Step 3: Evaluate The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real robot (SO-100 with EE space): @@ -121,10 +92,20 @@ python examples/umi_pi0_relative_ee/evaluate.py Edit `HF_MODEL_ID`, `HF_DATASET_ID`, and robot configuration at the top of the file. +### Latency compensation + +For real robot deployment, you may want to skip the first few steps of each predicted action chunk to compensate for system latency. Set `LATENCY_SKIP_STEPS` in the evaluate script: + +```python +LATENCY_SKIP_STEPS = 0 # ceil(total_latency_ms / (1000 / FPS)) +``` + +For example, at 10Hz with ~200ms total latency, set `LATENCY_SKIP_STEPS = 2`. + The inference flow uses pi0's built-in processor pipeline — no custom wrappers needed: 1. **Robot → FK** — Joint positions are converted to EE pose via `ForwardKinematicsJointsToEE`, producing `observation.state`. -2. **Preprocessor** — `RelativeActionsProcessorStep` caches the raw `observation.state`, then `NormalizerProcessorStep` normalizes everything. +2. **Preprocessor** — `DeriveStateFromActionStep` is a no-op (state comes from the robot). `RelativeActionsProcessorStep` caches the raw state. `RelativeStateProcessorStep` buffers the previous state, stacks, and converts to relative. `NormalizerProcessorStep` normalizes. 3. **pi0 inference** — The model predicts a normalized relative action chunk. 4. 
**Postprocessor** — `UnnormalizerProcessorStep` unnormalizes, then `AbsoluteActionsProcessorStep` adds the cached state back to get absolute EE targets. 5. **IK → Robot** — `InverseKinematicsEEToJoints` converts absolute EE targets to joint commands. @@ -132,14 +113,22 @@ The inference flow uses pi0's built-in processor pipeline — no custom wrappers ## How the Pieces Fit Together ``` -Training: - dataset (absolute EE) → RelativeActionsProcessorStep → NormalizerProcessorStep → pi0 model +Training (full UMI mode: derive_state_from_action=true): + DataLoader (action: B,51,D) + → DeriveStateFromActionStep (state = action[:,:2,:], action = action[:,1:,:]) + → RelativeActionsProcessorStep (action -= state[:,-1,:]) + → RelativeStateProcessorStep (state offsets from current, flatten → B,2*D) + → NormalizerProcessorStep → pi0 model Inference: robot joints → FK → observation.state (absolute EE) ↓ + DeriveStateFromActionStep (no-op) + ↓ RelativeActionsProcessorStep (caches state) ↓ + RelativeStateProcessorStep (buffers prev, stacks, subtracts, flattens) + ↓ NormalizerProcessorStep → pi0 model → relative action chunk ↓ UnnormalizerProcessorStep @@ -149,6 +138,31 @@ Inference: IK → joint targets → robot ``` +## Manual mode (without derive_state_from_action) + +If your dataset already has `observation.state` (or you want to add it separately), you can skip `derive_state_from_action` and use relative actions + relative state independently: + +```bash +# Recompute stats +lerobot-edit-dataset \ + --repo_id \ + --operation.type recompute_stats \ + --operation.relative_action true \ + --operation.relative_state true \ + --operation.state_obs_steps 2 \ + --operation.chunk_size 50 \ + --operation.relative_exclude_joints "['gripper']" + +# Train +lerobot-train \ + --dataset.repo_id= \ + --policy.type=pi0 \ + --policy.use_relative_actions=true \ + --policy.use_relative_state=true \ + --policy.state_obs_steps=2 \ + --policy.relative_exclude_joints='["gripper"]' +``` + ## References - 
[UMI: Universal Manipulation Interface](https://umi-gripper.github.io) — Chi et al., 2024. Defines relative trajectory actions. diff --git a/examples/umi_pi0_relative_ee/convert_umi_dataset.py b/examples/umi_pi0_relative_ee/convert_umi_dataset.py deleted file mode 100644 index ca14aea62..000000000 --- a/examples/umi_pi0_relative_ee/convert_umi_dataset.py +++ /dev/null @@ -1,220 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Add ``observation.state`` to an existing LeRobot dataset. - -pi0 uses ``observation.state`` as its proprioceptive input AND for -relative action conversion (action − state). This script creates -``observation.state`` by concatenating one or more existing features. - -Ordering matters: the features whose dimensions correspond to ``action`` -must come FIRST, because ``RelativeActionsProcessorStep`` subtracts -``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE -pose) are appended after and are seen by the model but not used for -relative conversion. - -Example: action = [proximal, distal], and we want the model to also see -the EE pose: - - STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"] - → observation.state = [proximal, distal, x, y, z, ax, ay, az] - -The relative conversion uses state[:2] = [proximal, distal] to subtract -from action[:2], and the model sees all 8 dimensions. 
- -After running this script, recompute relative action stats: - - lerobot-edit-dataset \\ - --repo_id \\ - --operation.type recompute_stats \\ - --operation.relative_action true \\ - --operation.chunk_size 50 \\ - --operation.relative_exclude_joints "[]" \\ - --push_to_hub true - -Usage: - python convert_umi_dataset.py -""" - -from __future__ import annotations - -import logging -from collections.abc import Callable - -import numpy as np - -from lerobot.datasets.dataset_tools import add_features -from lerobot.datasets.lerobot_dataset import LeRobotDataset - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -HF_DATASET_ID = "" - -# Output repo ID. Set to None for default "_modified". -OUTPUT_REPO_ID: str | None = None - -# Features to concatenate into observation.state. Order matters: -# action-matching features FIRST, then extra proprioception. -# Set to a single string to copy one feature directly. -STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"] - -# Only used when STATE_SOURCE_FEATURES is None: -# derive state from action with a per-episode offset. -STATE_ACTION_OFFSET = 1 - -# Push the augmented dataset to the Hugging Face Hub. 
-PUSH_TO_HUB = True - - -def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]: - hf = dataset.hf_dataset - ep = np.array(hf["episode_index"]) - fr = np.array(hf["frame_index"]) - return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))} - - -def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable: - """Concatenate multiple features into observation.state.""" - hf = dataset.hf_dataset - key_to_global = _build_global_index(dataset) - - columns = [hf[feat] for feat in source_features] - - def _get_state(row_dict: dict, ep_idx: int, frame_idx: int): - g = key_to_global[(ep_idx, frame_idx)] - parts = [] - for col in columns: - val = col[g] - if hasattr(val, "tolist"): - flat = val.tolist() - if isinstance(flat, list): - parts.extend(flat) - else: - parts.append(flat) - elif isinstance(val, list): - parts.extend(val) - else: - parts.append(float(val)) - return parts - - return _get_state - - -def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable: - """Derive state from action with a per-episode offset.""" - hf = dataset.hf_dataset - episode_indices = np.array(hf["episode_index"]) - frame_indices = np.array(hf["frame_index"]) - - ep_sorted: dict[int, list[tuple[int, int]]] = {} - for ep_idx in np.unique(episode_indices): - ep_mask = episode_indices == ep_idx - ep_globals = np.where(ep_mask)[0] - ep_frames = frame_indices[ep_globals] - order = np.argsort(ep_frames) - ep_sorted[int(ep_idx)] = [(int(ep_frames[o]), int(ep_globals[o])) for o in order] - - ep_frame_to_local: dict[int, dict[int, int]] = {} - for ep_idx, sorted_list in ep_sorted.items(): - ep_frame_to_local[ep_idx] = {frame: local for local, (frame, _) in enumerate(sorted_list)} - - actions = hf["action"] - - def _get_state(row_dict: dict, ep_idx: int, frame_idx: int): - local_t = ep_frame_to_local[ep_idx][frame_idx] - source_local = max(0, local_t - offset) - _, source_global = ep_sorted[ep_idx][source_local] - 
return actions[source_global] - - return _get_state - - -def main(): - logger.info(f"Loading dataset {HF_DATASET_ID}") - dataset = LeRobotDataset(HF_DATASET_ID) - - if "observation.state" in dataset.features: - logger.info("observation.state already exists — nothing to do") - return - - action_meta = dataset.features["action"] - logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}") - - if STATE_SOURCE_FEATURES is not None: - source_list = ( - [STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES) - ) - for feat in source_list: - if feat not in dataset.features: - raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}") - - # Compute combined shape and names - total_dim = 0 - all_names = [] - for feat in source_list: - meta = dataset.features[feat] - total_dim += meta["shape"][0] - names = meta.get("names") - if names: - all_names.extend(names) - - logger.info( - f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})" - ) - state_fn = _build_state_from_features(dataset, source_list) - state_feature_info = { - "dtype": "float32", - "shape": [total_dim], - "names": all_names or None, - } - else: - logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}") - state_fn = _build_state_from_action_offset(dataset, offset=STATE_ACTION_OFFSET) - state_feature_info = { - "dtype": "float32", - "shape": list(action_meta["shape"]), - "names": action_meta.get("names"), - } - - augmented = add_features( - dataset, - features={"observation.state": (state_fn, state_feature_info)}, - repo_id=OUTPUT_REPO_ID, - ) - logger.info("observation.state added") - - if PUSH_TO_HUB: - logger.info(f"Pushing to Hub: {augmented.repo_id}") - augmented.push_to_hub() - - logger.info( - f"Done. 
Dataset at: {augmented.root}\n" - "Now recompute relative action stats:\n" - " lerobot-edit-dataset \\\n" - f" --repo_id {augmented.repo_id} \\\n" - " --operation.type recompute_stats \\\n" - " --operation.relative_action true \\\n" - " --operation.chunk_size 50 \\\n" - ' --operation.relative_exclude_joints "[]" \\\n' - " --push_to_hub true" - ) - - -if __name__ == "__main__": - main() diff --git a/examples/umi_pi0_relative_ee/evaluate.py b/examples/umi_pi0_relative_ee/evaluate.py index a2906fd48..d77e3213d 100644 --- a/examples/umi_pi0_relative_ee/evaluate.py +++ b/examples/umi_pi0_relative_ee/evaluate.py @@ -17,17 +17,20 @@ """ Inference script for a pi0 model trained with **relative EE actions**. -This uses the built-in ``RelativeActionsProcessorStep`` and -``AbsoluteActionsProcessorStep`` that are already wired into pi0's -processor pipeline when ``use_relative_actions=True``. +This uses the built-in ``DeriveStateFromActionStep`` (no-op at inference), +``RelativeActionsProcessorStep``, ``AbsoluteActionsProcessorStep``, and +``RelativeStateProcessorStep`` that are already wired into pi0's processor +pipeline. The inference loop: 1. Reads joint positions from the robot. 2. Converts to EE pose via forward kinematics (FK). This produces ``observation.state`` with the current EE pose. 3. The pi0 preprocessor: - a) ``RelativeActionsProcessorStep`` caches the raw state. - b) ``NormalizerProcessorStep`` normalizes state and actions. + a) ``DeriveStateFromActionStep`` — no-op (state comes from robot). + b) ``RelativeActionsProcessorStep`` caches the raw state. + c) ``RelativeStateProcessorStep`` buffers prev state, stacks, subtracts. + d) ``NormalizerProcessorStep`` normalizes state and actions. 4. pi0 predicts relative action chunk. 5. The pi0 postprocessor: a) ``UnnormalizerProcessorStep`` unnormalizes. 
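The latency compensation used by this script follows a simple formula. A hypothetical helper mirroring the `ceil(total_latency_ms / (1000 / FPS))` rule, with chunk trimming shown on a stand-in list:

```python
import math

def latency_skip_steps(total_latency_ms: float, fps: float) -> int:
    # Number of leading chunk steps already "in the past" by the time the
    # prediction arrives: latency divided by the control period, rounded up.
    return math.ceil(total_latency_ms / (1000.0 / fps))

action_chunk = list(range(50))       # stand-in for a predicted action chunk
skip = latency_skip_steps(200, 10)   # ~200 ms latency at 10 Hz -> 2
executed = action_chunk[skip:]       # execute from step `skip` onward
```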
@@ -51,6 +54,7 @@ from lerobot.model.kinematics import RobotKinematics from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.pi0.modeling_pi0 import PI0Policy from lerobot.processor import ( + RelativeStateProcessorStep, RobotProcessorPipeline, make_default_teleop_action_processor, ) @@ -79,6 +83,11 @@ TASK_DESCRIPTION = "manipulation task" HF_MODEL_ID = "/" HF_DATASET_ID = "/" +# Latency compensation: skip this many steps from the start of each predicted +# action chunk. Formula: ceil(total_latency_ms / (1000 / FPS)). +# E.g. at 10Hz with ~200ms total system latency: ceil(200 / 100) = 2. +LATENCY_SKIP_STEPS = 0 + # EE feature keys produced by ForwardKinematicsJointsToEE EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"] @@ -94,6 +103,7 @@ def main(): robot = SO100Follower(robot_config) policy = PI0Policy.from_pretrained(HF_MODEL_ID) + policy.config.latency_skip_steps = LATENCY_SKIP_STEPS kinematics_solver = RobotKinematics( urdf_path="./SO101/so101_new_calib.urdf", @@ -151,9 +161,8 @@ def main(): # Build pre/post processors from the trained model. # The pi0 processor pipeline already includes: - # pre: ... → RelativeActionsProcessorStep → NormalizerProcessorStep + # pre: ... → RelativeActionsProcessorStep → RelativeStateProcessorStep → NormalizerProcessorStep # post: UnnormalizerProcessorStep → AbsoluteActionsProcessorStep → ... - # These handle the relative ↔ absolute conversion automatically. preprocessor, postprocessor = make_pre_post_processors( policy_cfg=policy, pretrained_path=HF_MODEL_ID, @@ -161,6 +170,9 @@ preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, ) + # Find the relative state step (if present) so we can reset its buffer between episodes. 
+ _relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)] + robot.connect() listener, events = init_keyboard_listener() @@ -174,6 +186,10 @@ def main(): for episode_idx in range(NUM_EPISODES): log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + # Reset relative state buffer so velocity is zero at episode start + for step in _relative_state_steps: + step.reset() + record_loop( robot=robot, events=events, diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index ce567b8f5..2891b3afc 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -115,6 +115,17 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation raise NotImplementedError + @property + def state_delta_indices(self) -> list | None: # type: ignore[type-arg] + """Delta indices specifically for observation.state. + + When not None, overrides ``observation_delta_indices`` for the + ``observation.state`` key only. Useful for loading state history + (e.g. ``[-1, 0]`` for UMI-style relative proprioception) without + also loading multiple image timesteps. + """ + return None + @abc.abstractmethod def get_optimizer_preset(self) -> OptimizerConfig: raise NotImplementedError diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index 03eefe40e..a90e3ed8a 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -767,3 +767,94 @@ def compute_relative_action_stats( ) return stats + + +def compute_relative_state_stats( + hf_dataset, + features: dict, + state_obs_steps: int = 2, + exclude_joints: list[str] | None = None, + source_key: str = OBS_STATE, +) -> dict[str, np.ndarray]: + """Compute normalization statistics for observation.state after relative conversion. 
+ + For UMI-style relative proprioception with ``state_obs_steps`` timesteps, + each state observation becomes a stack of offsets from the current timestep: + ``state[t-k] - state[t]`` for k in ``range(state_obs_steps-1, -1, -1)``. + + The stats are computed over the flattened ``[state_obs_steps * state_dim]`` + vector that the model actually sees after ``prepare_state`` flattening. + + Args: + hf_dataset: The HuggingFace dataset with the source column and + "episode_index" columns. + features: Dataset feature metadata. + state_obs_steps: Number of observation timesteps (must be >= 2). + exclude_joints: State dimension names to keep absolute. + source_key: Column to read data from. Defaults to "observation.state". + When ``derive_state_from_action=True``, pass ``ACTION`` to read + from the action column instead. + + Returns: + Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99". + """ + from lerobot.processor.relative_action_processor import RelativeStateProcessorStep + + if exclude_joints is None: + exclude_joints = [] + + state_dim = features[source_key]["shape"][0] + state_names = features.get(source_key, {}).get("names") + mask_step = RelativeStateProcessorStep( + enabled=True, + exclude_joints=exclude_joints, + state_names=state_names, + ) + relative_mask = np.array(mask_step._build_mask(state_dim), dtype=np.float32) + + logging.info(f"Loading data from '{source_key}' for relative state stats...") + all_states = np.array(hf_dataset[source_key], dtype=np.float32) + episode_indices = np.array(hf_dataset["episode_index"]) + + # Build all valid windows of length state_obs_steps within each episode + n = len(all_states) + if n < state_obs_steps: + raise ValueError(f"Dataset has {n} frames but state_obs_steps={state_obs_steps}") + + max_start = n - state_obs_steps + starts = np.arange(max_start + 1) + valid = episode_indices[starts] == episode_indices[starts + state_obs_steps - 1] + valid_starts = starts[valid] + + if len(valid_starts) == 0: + 
raise RuntimeError("No valid state windows found within single episodes") + + offsets = np.arange(state_obs_steps) + mask_dim = len(relative_mask) + + running_stats = RunningQuantileStats() + + batch_size = 50_000 + for i in range(0, len(valid_starts), batch_size): + batch_starts = valid_starts[i : i + batch_size] + frame_idx = batch_starts[:, None] + offsets[None, :] # [N, state_obs_steps] + windows = all_states[frame_idx].copy() # [N, state_obs_steps, state_dim] + + # Subtract current (last) timestep from all timesteps for masked dims + current = windows[:, -1:, :] # [N, 1, state_dim] + windows[:, :, :mask_dim] -= current[:, :, :mask_dim] * relative_mask[None, None, :] + + # Flatten to [N, state_obs_steps * state_dim] (same as prepare_state) + flattened = windows.reshape(len(batch_starts), -1) + running_stats.update(flattened) + + stats = running_stats.get_statistics() + + excluded_dims = int(mask_dim - relative_mask.sum()) + logging.info( + f"Relative state stats ({len(valid_starts)} windows, obs_steps={state_obs_steps}): " + f"relative_dims={int(relative_mask.sum())}/{mask_dim} (excluded={excluded_dims}), " + f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}" + ) + + return stats diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 16bf24822..79dec0204 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -41,6 +41,7 @@ from lerobot.datasets.compute_stats import ( aggregate_stats, compute_episode_stats, compute_relative_action_stats, + compute_relative_state_stats, ) from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.datasets.io_utils import ( @@ -1544,6 +1545,10 @@ def recompute_stats( relative_exclude_joints: list[str] | None = None, chunk_size: int = 50, num_workers: int = 0, + relative_state: bool = False, + relative_exclude_state_joints: list[str] | None = None, + state_obs_steps: int = 2, + derive_state_from_action: 
bool = False, ) -> LeRobotDataset: """Recompute stats.json from scratch by iterating all episodes. @@ -1561,10 +1566,22 @@ def recompute_stats( ``policy.chunk_size``. Only used when ``relative_action=True``. num_workers: Number of parallel threads for relative action stats computation. Values ≤1 mean single-threaded. Only used when ``relative_action=True``. + relative_state: If True, compute observation.state stats in relative space + (multi-timestep offsets from current). This matches the normalization + the model sees during training with ``use_relative_state=True``. + relative_exclude_state_joints: State dim names to exclude from relative conversion. + state_obs_steps: Number of observation timesteps for relative state stats. + Should match ``policy.state_obs_steps``. Only used when ``relative_state=True``. + derive_state_from_action: If True, compute relative state stats from the + action column instead of observation.state. Implies ``relative_state=True`` + and ``state_obs_steps=2``. Returns: The same dataset with updated stats. """ + if derive_state_from_action: + relative_state = True + state_obs_steps = 2 features = dataset.meta.features meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"} numeric_features = { @@ -1596,6 +1613,20 @@ def recompute_stats( ) features_to_compute.pop(ACTION, None) + # When relative_state is enabled, compute state stats over the flattened + # multi-timestep relative representation (matching what the model sees). 
+ relative_state_stats = None + if relative_state and (OBS_STATE in features or derive_state_from_action): + source_key = ACTION if derive_state_from_action else OBS_STATE + relative_state_stats = compute_relative_state_stats( + hf_dataset=dataset.hf_dataset, + features=features, + state_obs_steps=state_obs_steps, + exclude_joints=relative_exclude_state_joints, + source_key=source_key, + ) + features_to_compute.pop(OBS_STATE, None) + logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}") data_dir = dataset.root / DATA_DIR @@ -1632,6 +1663,9 @@ def recompute_stats( if relative_action_stats is not None: new_stats[ACTION] = relative_action_stats + if relative_state_stats is not None: + new_stats[OBS_STATE] = relative_state_stats + # Merge: keep existing stats for features we didn't recompute if dataset.meta.stats: for key, value in dataset.meta.stats.items(): diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 76ece8961..28b1568f9 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -25,7 +25,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.multi_dataset import MultiLeRobotDataset from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms -from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD +from lerobot.utils.constants import ACTION, OBS_PREFIX, OBS_STATE, REWARD IMAGENET_STATS = { "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) @@ -52,12 +52,15 @@ def resolve_delta_timestamps( returns `None` if the resulting dict is empty. 
""" delta_timestamps = {} + state_delta = getattr(cfg, "state_delta_indices", None) for key in ds_meta.features: if key == REWARD and cfg.reward_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] if key == ACTION and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] - if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: + if key == OBS_STATE and state_delta is not None: + delta_timestamps[key] = [i / ds_meta.fps for i in state_delta] + elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] if len(delta_timestamps) == 0: diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index cf4b636a3..c18d5e60a 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -57,6 +57,28 @@ class PI0Config(PreTrainedConfig): # Populated at runtime from dataset metadata by make_policy. action_feature_names: list[str] | None = None + # Relative state (UMI-style relative proprioception): converts multi-timestep + # observation.state to offsets from the current timestep, providing velocity info. + # Requires state_obs_steps >= 2. The flattened multi-timestep state is padded to + # max_state_dim, so ensure state_obs_steps * state_dim <= max_state_dim. + use_relative_state: bool = False + state_obs_steps: int = 1 + relative_exclude_state_joints: list[str] = field(default_factory=list) + # Populated at runtime from dataset metadata by make_policy. + state_feature_names: list[str] | None = None + + # Derive observation.state from the action column (UMI-style). 
+ # When True, action_delta_indices loads one extra leading timestep [-1, 0, ..., chunk_size-1], + # DeriveStateFromActionStep extracts [action[t-1], action[t]] as a 2-step state, + # and strips the extra timestep from the action chunk. + # Implies use_relative_state=True and state_obs_steps=2. + derive_state_from_action: bool = False + + # Latency compensation: skip this many steps from the start of each predicted + # action chunk during inference. E.g. at 10Hz with ~200ms total latency, + # latency_skip_steps=2 compensates for the delay. + latency_skip_steps: int = 0 + # Real-Time Chunking (RTC) configuration rtc_config: RTCConfig | None = None @@ -106,6 +128,10 @@ class PI0Config(PreTrainedConfig): def __post_init__(self): super().__post_init__() + if self.derive_state_from_action: + self.use_relative_state = True + self.state_obs_steps = 2 + # Validate configuration if self.n_action_steps > self.chunk_size: raise ValueError( @@ -121,6 +147,13 @@ class PI0Config(PreTrainedConfig): if self.dtype not in ["bfloat16", "float32"]: raise ValueError(f"Invalid dtype: {self.dtype}") + if self.use_relative_state and self.state_obs_steps < 2: + raise ValueError( + "use_relative_state requires state_obs_steps >= 2 " + f"(got {self.state_obs_steps}). Set state_obs_steps=2 for " + "UMI-style relative proprioception." 
+ ) + def validate_features(self) -> None: """Validate and set up input/output features.""" for i in range(self.empty_cameras): @@ -166,8 +199,16 @@ class PI0Config(PreTrainedConfig): def observation_delta_indices(self) -> None: return None + @property + def state_delta_indices(self) -> list[int] | None: + if self.state_obs_steps >= 2: + return list(range(-(self.state_obs_steps - 1), 1)) + return None + @property def action_delta_indices(self) -> list: + if self.derive_state_from_action: + return [-1] + list(range(self.chunk_size)) return list(range(self.chunk_size)) @property diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index aebf32964..e63fac360 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -1230,8 +1230,11 @@ class PI0Policy(PreTrainedPolicy): return images, img_masks def prepare_state(self, batch): - """Pad state""" - state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) + """Flatten multi-timestep state and pad to max_state_dim.""" + state = batch[OBS_STATE] + if state.ndim == 3: + state = state.flatten(start_dim=1) + state = pad_vector(state, self.config.max_state_dim) return state def prepare_action(self, batch): @@ -1250,7 +1253,8 @@ class PI0Policy(PreTrainedPolicy): # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: - actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + skip = self.config.latency_skip_steps + actions = self.predict_action_chunk(batch)[:, skip : skip + self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index 0302876a1..d100577b3 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -24,6 +24,7 @@ from lerobot.processor import ( 
AbsoluteActionsProcessorStep, AddBatchDimensionProcessorStep, ComplementaryDataProcessorStep, + DeriveStateFromActionStep, DeviceProcessorStep, NormalizerProcessorStep, PolicyAction, @@ -31,6 +32,7 @@ from lerobot.processor import ( ProcessorStep, ProcessorStepRegistry, RelativeActionsProcessorStep, + RelativeStateProcessorStep, RenameObservationsProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, @@ -128,13 +130,25 @@ def make_pi0_pre_post_processors( A tuple containing the configured pre-processor and post-processor pipelines. """ + derive_state_step = DeriveStateFromActionStep( + enabled=getattr(config, "derive_state_from_action", False), + ) + relative_step = RelativeActionsProcessorStep( enabled=config.use_relative_actions, exclude_joints=getattr(config, "relative_exclude_joints", []), action_names=getattr(config, "action_feature_names", None), ) - # OpenPI order: raw → relative → normalize → model → unnormalize → absolute + relative_state_step = RelativeStateProcessorStep( + enabled=getattr(config, "use_relative_state", False), + exclude_joints=getattr(config, "relative_exclude_state_joints", []), + state_names=getattr(config, "state_feature_names", None), + ) + + # Order: DeriveStateFromAction extracts state from the extended action chunk, + # then relative_action uses current state[t] for subtraction, + # then relative_state converts the multi-timestep state to offsets. 
input_steps: list[ProcessorStep] = [ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one AddBatchDimensionProcessorStep(), @@ -146,7 +160,9 @@ def make_pi0_pre_post_processors( padding="max_length", ), DeviceProcessorStep(device=config.device), + derive_state_step, relative_step, + relative_state_step, NormalizerProcessorStep( features={**config.input_features, **config.output_features}, norm_map=config.normalization_mapping, diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 122b3533c..584f63a93 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -77,9 +77,12 @@ from .policy_robot_bridge import ( ) from .relative_action_processor import ( AbsoluteActionsProcessorStep, + DeriveStateFromActionStep, RelativeActionsProcessorStep, + RelativeStateProcessorStep, to_absolute_actions, to_relative_actions, + to_relative_state, ) from .rename_processor import RenameObservationsProcessorStep from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep @@ -107,7 +110,9 @@ __all__ = [ "make_default_robot_action_processor", "make_default_robot_observation_processor", "AbsoluteActionsProcessorStep", + "DeriveStateFromActionStep", "RelativeActionsProcessorStep", + "RelativeStateProcessorStep", "MapDeltaActionToRobotActionStep", "MapTensorToDeltaActionDictStep", "NormalizerProcessorStep", @@ -139,6 +144,7 @@ __all__ = [ "TruncatedProcessorStep", "to_absolute_actions", "to_relative_actions", + "to_relative_state", "UnnormalizerProcessorStep", "VanillaObservationProcessorStep", ] diff --git a/src/lerobot/processor/relative_action_processor.py b/src/lerobot/processor/relative_action_processor.py index e00d26e98..1076b71d4 100644 --- a/src/lerobot/processor/relative_action_processor.py +++ b/src/lerobot/processor/relative_action_processor.py @@ -30,10 +30,13 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry __all__ = [ 
"MapDeltaActionToRobotActionStep", "MapTensorToDeltaActionDictStep", + "DeriveStateFromActionStep", "RelativeActionsProcessorStep", "AbsoluteActionsProcessorStep", + "RelativeStateProcessorStep", "to_relative_actions", "to_absolute_actions", + "to_relative_state", ] @@ -81,6 +84,41 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> return actions +@ProcessorStepRegistry.register("derive_state_from_action_processor") +@dataclass +class DeriveStateFromActionStep(ProcessorStep): + """Derives 2-step observation.state from the action chunk (UMI-style). + + Expects action with one extra leading timestep: [B, chunk_size+1, D] + from action_delta_indices = [-1, 0, 1, ..., chunk_size-1]. + Extracts [action[t-1], action[t]] as state and strips the extra timestep. + No-op during inference (state comes from robot). + """ + + enabled: bool = False + + def __call__(self, transition: EnvTransition) -> EnvTransition: + if not self.enabled: + return transition + action = transition.get(TransitionKey.ACTION) + if action is None or action.ndim < 3: + return transition + new_transition = transition.copy() + new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {})) + new_obs[OBS_STATE] = action[..., :2, :] + new_transition[TransitionKey.ACTION] = action[..., 1:, :] + new_transition[TransitionKey.OBSERVATION] = new_obs + return new_transition + + def get_config(self) -> dict[str, Any]: + return {"enabled": self.enabled} + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + @ProcessorStepRegistry.register("delta_actions_processor") @dataclass class RelativeActionsProcessorStep(ProcessorStep): @@ -124,7 +162,14 @@ class RelativeActionsProcessorStep(ProcessorStep): def __call__(self, transition: EnvTransition) -> EnvTransition: observation = transition.get(TransitionKey.OBSERVATION, {}) - state = observation.get(OBS_STATE) if 
observation else None + raw_state = observation.get(OBS_STATE) if observation else None + + # When state_delta_indices loads multi-timestep state [B, n_obs, D], + # use only the current (last) timestep for relative action conversion. + if raw_state is not None: + state = raw_state[..., -1, :] if raw_state.ndim >= 3 else raw_state + else: + state = None # Always cache state for the paired AbsoluteActionsProcessorStep if state is not None: @@ -155,6 +200,120 @@ class RelativeActionsProcessorStep(ProcessorStep): return features +def to_relative_state(state: Tensor, mask: Sequence[bool]) -> Tensor: + """Convert multi-timestep absolute state to relative (offset from current timestep). + + Each timestep becomes: ``state[..., t, :] - state[..., -1, :]`` for masked dims. + The last (current) timestep becomes zeros for masked dims. + + Args: + state: (..., n_obs, state_dim) — last timestep is the reference (current). + mask: Which dims to convert. Can be shorter than state_dim. + """ + mask_t = torch.tensor(mask, dtype=state.dtype, device=state.device) + dims = mask_t.shape[0] + current = state[..., -1:, :] # (..., 1, state_dim) + state = state.clone() + state[..., :dims] -= current[..., :dims] * mask_t + return state + + +@ProcessorStepRegistry.register("relative_state_processor") +@dataclass +class RelativeStateProcessorStep(ProcessorStep): + """Converts observation.state to relative (offset from current timestep). + + UMI-style relative proprioception: each state timestep is expressed as + an offset from the current EE pose, providing velocity information. + + During training (multi-timestep input from ``state_delta_indices``): + ``state[..., t, :] -= state[..., -1, :]`` — subtract current from all. + + During inference (single timestep): buffers the previous state and stacks + ``[previous, current]`` before applying the relative conversion, producing + the same ``[n_obs, D]`` shape the model expects. + + Attributes: + enabled: Whether to apply the relative conversion. 
+ exclude_joints: Joint/dim names to keep absolute. + state_names: State dimension names from dataset metadata. + """ + + enabled: bool = False + exclude_joints: list[str] = field(default_factory=list) + state_names: list[str] | None = None + _previous_state: torch.Tensor | None = field(default=None, init=False, repr=False) + + def _build_mask(self, state_dim: int) -> list[bool]: + if not self.exclude_joints or self.state_names is None: + return [True] * state_dim + + exclude_tokens = [str(name).lower() for name in self.exclude_joints if name] + if not exclude_tokens: + return [True] * state_dim + + mask = [] + for name in self.state_names[:state_dim]: + state_name = str(name).lower() + is_excluded = any(token == state_name or token in state_name for token in exclude_tokens) + mask.append(not is_excluded) + + if len(mask) < state_dim: + mask.extend([True] * (state_dim - len(mask))) + + return mask + + def __call__(self, transition: EnvTransition) -> EnvTransition: + if not self.enabled: + return transition + + observation = transition.get(TransitionKey.OBSERVATION, {}) + state = observation.get(OBS_STATE) if observation else None + + if state is None: + return transition + + new_transition = transition.copy() + new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {})) + mask = self._build_mask(state.shape[-1]) + + if state.ndim >= 3: + # [B, n_obs, D] — multi-timestep (training with state_delta_indices) + relative = to_relative_state(state, mask) + new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, n_obs*D] + elif state.ndim == 2: + # [B, D] — single timestep (inference): buffer previous and stack + current = state + if self._previous_state is None: + self._previous_state = current.clone() + prev = self._previous_state + if prev.device != current.device or prev.dtype != current.dtype: + prev = prev.to(device=current.device, dtype=current.dtype) + stacked = torch.stack([prev, current], dim=-2) # [B, 2, D] + relative = to_relative_state(stacked, mask) 
+ new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, 2*D] + self._previous_state = current.clone() + + new_transition[TransitionKey.OBSERVATION] = new_obs + return new_transition + + def reset(self) -> None: + """Reset the state buffer. Call at episode boundaries during inference.""" + self._previous_state = None + + def get_config(self) -> dict[str, Any]: + return { + "enabled": self.enabled, + "exclude_joints": self.exclude_joints, + "state_names": self.state_names, + } + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + @ProcessorStepRegistry.register("absolute_actions_processor") @dataclass class AbsoluteActionsProcessorStep(ProcessorStep): diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index db06f90c6..9fd64ed68 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -254,6 +254,10 @@ class RecomputeStatsConfig(OperationConfig): relative_exclude_joints: list[str] | None = None chunk_size: int = 50 num_workers: int = 0 + relative_state: bool = False + relative_exclude_state_joints: list[str] | None = None + state_obs_steps: int = 2 + derive_state_from_action: bool = False @OperationConfig.register_subclass("info") @@ -563,6 +567,14 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None: f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, " f"exclude_joints={cfg.operation.relative_exclude_joints})" ) + if cfg.operation.relative_state: + logging.info( + f"Relative state stats enabled (state_obs_steps={cfg.operation.state_obs_steps}, " + f"exclude_state_joints={cfg.operation.relative_exclude_state_joints})" + ) + + if cfg.operation.derive_state_from_action: + logging.info("Derive state from action enabled (implies relative_state=True, state_obs_steps=2)") recompute_stats( dataset, @@ -571,6 
+583,10 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None: relative_exclude_joints=cfg.operation.relative_exclude_joints, chunk_size=cfg.operation.chunk_size, num_workers=cfg.operation.num_workers, + relative_state=cfg.operation.relative_state, + relative_exclude_state_joints=cfg.operation.relative_exclude_state_joints, + state_obs_steps=cfg.operation.state_obs_steps, + derive_state_from_action=cfg.operation.derive_state_from_action, ) logging.info(f"Stats written to {dataset.root}")
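Taken together, `DeriveStateFromActionStep` and `to_relative_state` reduce to simple slicing and subtraction. A minimal pure-Python sketch of that arithmetic (no torch, plain nested lists instead of batched tensors; the function names mirror the diff but these are standalone stand-ins, not the real implementations):

```python
# Sketch of the two UMI-style steps added in this diff. The extended action
# chunk has one extra leading timestep, so a chunk of length C arrives as
# C + 1 rows: [a(t-1), a(t), ..., a(t+C-1)].

def derive_state_from_action(chunk):
    """Split the extended chunk into a 2-step state [a(t-1), a(t)]
    and the usual C-step action chunk [a(t), ...]."""
    state = chunk[:2]      # action[..., :2, :] in the diff
    actions = chunk[1:]    # action[..., 1:, :] in the diff
    return state, actions

def to_relative_state(state, mask):
    """Express each timestep as an offset from the last (current) one for
    the dims selected by mask; the current timestep becomes zeros there."""
    current = state[-1]
    return [
        [s - c if m else s for s, c, m in zip(row, current, mask)]
        for row in state
    ]

# Extended chunk for chunk_size=2 with a toy 3-dim pose; rows are timesteps.
chunk = [[0.0, 0.0, 1.0],   # a(t-1)
         [1.0, 2.0, 1.0],   # a(t)
         [2.0, 4.0, 1.0]]   # a(t+1)

state, actions = derive_state_from_action(chunk)
# Keep the last dim (e.g. gripper) absolute, as exclude_joints would.
rel = to_relative_state(state, [True, True, False])
print(rel)      # [[-1.0, -2.0, 1.0], [0.0, 0.0, 1.0]]
print(actions)  # [[1.0, 2.0, 1.0], [2.0, 4.0, 1.0]]
```

Note that the current timestep of the relative state is all zeros on the converted dims, which is why the model gets velocity information only from the earlier offsets.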
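The delta-index bookkeeping in `PI0Config` and the inference-time latency slice are easy to get off by one, so here is a hedged sketch of just that arithmetic (illustrative free functions, not the real config properties):

```python
# Sketch of the index bookkeeping the new PI0Config properties implement.

def state_delta_indices(state_obs_steps):
    # range(-(n-1), 1): the current frame plus n-1 past frames.
    if state_obs_steps >= 2:
        return list(range(-(state_obs_steps - 1), 1))
    return None

def action_delta_indices(chunk_size, derive_state_from_action):
    # One extra leading timestep when state is derived from the action column.
    base = list(range(chunk_size))
    return [-1] + base if derive_state_from_action else base

print(state_delta_indices(2))         # [-1, 0]
print(action_delta_indices(3, True))  # [-1, 0, 1, 2]

# Latency compensation at inference: drop the first `skip` predicted steps,
# then take n_action_steps, mirroring chunk[:, skip : skip + n_action_steps].
chunk = list(range(10))  # stand-in for a predicted action chunk
skip, n_action_steps = 2, 5
print(chunk[skip:skip + n_action_steps])  # [2, 3, 4, 5, 6]
```

At 10 Hz, `skip = 2` discards the first 200 ms of the chunk, matching the `latency_skip_steps` comment in the config diff.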