diff --git a/examples/umi_pi0_relative_ee/convert_umi_dataset.py b/examples/umi_pi0_relative_ee/convert_umi_dataset.py index 74e70ea4d..ca14aea62 100644 --- a/examples/umi_pi0_relative_ee/convert_umi_dataset.py +++ b/examples/umi_pi0_relative_ee/convert_umi_dataset.py @@ -17,28 +17,33 @@ """ Add ``observation.state`` to an existing LeRobot dataset. -pi0 with ``use_relative_actions=True`` requires ``observation.state`` to -compute relative actions (action − state) on the fly. This script adds -that feature when it doesn't already exist. +pi0 uses ``observation.state`` as its proprioceptive input AND for +relative action conversion (action − state). This script creates +``observation.state`` by concatenating one or more existing features. -Two modes for deriving ``observation.state``: +Ordering matters: the features whose dimensions correspond to ``action`` +must come FIRST, because ``RelativeActionsProcessorStep`` subtracts +``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE +pose) are appended after and are seen by the model but not used for +relative conversion. - 1. **From an existing feature** (``STATE_SOURCE_FEATURE``): - Copies an existing column (e.g. ``observation.joints`` or - ``observation.pose``) to ``observation.state``. +Example: action = [proximal, distal], and we want the model to also see +the EE pose: - 2. **From action with offset** (``STATE_SOURCE_FEATURE = None``): - Derives state from the action column with a per-episode offset: - state[t] = action[t - STATE_ACTION_OFFSET] + STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"] + → observation.state = [proximal, distal, x, y, z, ax, ay, az] -After running this script, recompute relative action stats via the CLI: +The relative conversion uses state[:2] = [proximal, distal] to subtract +from action[:2], and the model sees all 8 dimensions. + +After running this script, recompute relative action stats: lerobot-edit-dataset \\ --repo_id \\ --operation.type recompute_stats \\ --operation.relative_action true \\ --operation.chunk_size 50 \\ - --operation.relative_exclude_joints "['gripper']" \\ + --operation.relative_exclude_joints "[]" \\ --push_to_hub true Usage: @@ -61,41 +66,58 @@ logger = logging.getLogger(__name__) HF_DATASET_ID = "" -# Source for observation.state. Options: -# - A feature name (e.g. "observation.joints", "observation.pose") to copy -# an existing column. Must have the same shape as "action". -# - None to derive state from action with STATE_ACTION_OFFSET. -STATE_SOURCE_FEATURE: str | None = "observation.joints" +# Output repo ID. Set to None for default "_modified". +OUTPUT_REPO_ID: str | None = None -# Only used when STATE_SOURCE_FEATURE is None. -# 0 → state[t] = action[t] (same instant) -# 1 → state[t] = action[t-1] (state lags by 1 step) +# Features to concatenate into observation.state. Order matters: +# action-matching features FIRST, then extra proprioception. +# Set to a single string to copy one feature directly. +STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"] + +# Only used when STATE_SOURCE_FEATURES is None: +# derive state from action with a per-episode offset. STATE_ACTION_OFFSET = 1 # Push the augmented dataset to the Hugging Face Hub. PUSH_TO_HUB = True -def _build_state_from_feature(dataset: LeRobotDataset, source_feature: str) -> Callable: - """Return a callable that copies values from an existing feature.""" +def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]: hf = dataset.hf_dataset - source_values = hf[source_feature] + ep = np.array(hf["episode_index"]) + fr = np.array(hf["frame_index"]) + return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))} - episode_indices = np.array(hf["episode_index"]) - frame_indices = np.array(hf["frame_index"]) - key_to_global = {(int(episode_indices[i]), int(frame_indices[i])): i for i in range(len(episode_indices))} + +def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable: + """Concatenate multiple features into observation.state.""" + hf = dataset.hf_dataset + key_to_global = _build_global_index(dataset) + + columns = [hf[feat] for feat in source_features] def _get_state(row_dict: dict, ep_idx: int, frame_idx: int): - return source_values[key_to_global[(ep_idx, frame_idx)]] + g = key_to_global[(ep_idx, frame_idx)] + parts = [] + for col in columns: + val = col[g] + if hasattr(val, "tolist"): + flat = val.tolist() + if isinstance(flat, list): + parts.extend(flat) + else: + parts.append(flat) + elif isinstance(val, list): + parts.extend(val) + else: + parts.append(float(val)) + return parts return _get_state def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable: - """Return a callable that derives state from action with a per-episode offset. - - state[t] = action[max(0, t - offset)] (clamped to episode start) - """ + """Derive state from action with a per-episode offset.""" hf = dataset.hf_dataset episode_indices = np.array(hf["episode_index"]) frame_indices = np.array(hf["frame_index"]) @@ -134,19 +156,32 @@ def main(): action_meta = dataset.features["action"] logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}") - if STATE_SOURCE_FEATURE is not None: - if STATE_SOURCE_FEATURE not in dataset.features: - raise ValueError( - f"Source feature '{STATE_SOURCE_FEATURE}' not found. " - f"Available: {list(dataset.features.keys())}" - ) - source_meta = dataset.features[STATE_SOURCE_FEATURE] - logger.info(f"Copying {STATE_SOURCE_FEATURE} → observation.state") - state_fn = _build_state_from_feature(dataset, STATE_SOURCE_FEATURE) + if STATE_SOURCE_FEATURES is not None: + source_list = ( + [STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES) + ) + for feat in source_list: + if feat not in dataset.features: + raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}") + + # Compute combined shape and names + total_dim = 0 + all_names = [] + for feat in source_list: + meta = dataset.features[feat] + total_dim += meta["shape"][0] + names = meta.get("names") + if names: + all_names.extend(names) + + logger.info( + f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})" + ) + state_fn = _build_state_from_features(dataset, source_list) state_feature_info = { "dtype": "float32", - "shape": list(source_meta["shape"]), - "names": source_meta.get("names"), + "shape": [total_dim], + "names": all_names or None, } else: logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}") @@ -160,6 +195,7 @@ def main(): augmented = add_features( dataset, features={"observation.state": (state_fn, state_feature_info)}, + repo_id=OUTPUT_REPO_ID, ) logger.info("observation.state added") @@ -168,13 +204,14 @@ def main(): augmented.push_to_hub() logger.info( - f"Done. Now recompute relative action stats:\n" + f"Done. Dataset at: {augmented.root}\n" + "Now recompute relative action stats:\n" " lerobot-edit-dataset \\\n" f" --repo_id {augmented.repo_id} \\\n" " --operation.type recompute_stats \\\n" " --operation.relative_action true \\\n" " --operation.chunk_size 50 \\\n" - " --operation.relative_exclude_joints \"['gripper']\" \\\n" + ' --operation.relative_exclude_joints "[]" \\\n' " --push_to_hub true" )