This commit is contained in:
Pepijn
2026-04-01 15:29:59 +02:00
parent dfe16e8b84
commit 0fc855df13
@@ -17,28 +17,33 @@
"""
Add ``observation.state`` to an existing LeRobot dataset.
pi0 with ``use_relative_actions=True`` requires ``observation.state`` to
compute relative actions (action state) on the fly. This script adds
that feature when it doesn't already exist.
pi0 uses ``observation.state`` as its proprioceptive input AND for
relative action conversion (action state). This script creates
``observation.state`` by concatenating one or more existing features.
Two modes for deriving ``observation.state``:
Ordering matters: the features whose dimensions correspond to ``action``
must come FIRST, because ``RelativeActionsProcessorStep`` subtracts
``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE
pose) are appended after and are seen by the model but not used for
relative conversion.
1. **From an existing feature** (``STATE_SOURCE_FEATURE``):
Copies an existing column (e.g. ``observation.joints`` or
``observation.pose``) to ``observation.state``.
Example: action = [proximal, distal], and we want the model to also see
the EE pose:
2. **From action with offset** (``STATE_SOURCE_FEATURE = None``):
Derives state from the action column with a per-episode offset:
state[t] = action[t - STATE_ACTION_OFFSET]
STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"]
→ observation.state = [proximal, distal, x, y, z, ax, ay, az]
After running this script, recompute relative action stats via the CLI:
The relative conversion uses state[:2] = [proximal, distal] to subtract
from action[:2], and the model sees all 8 dimensions.
After running this script, recompute relative action stats:
lerobot-edit-dataset \\
--repo_id <your_dataset> \\
--operation.type recompute_stats \\
--operation.relative_action true \\
--operation.chunk_size 50 \\
--operation.relative_exclude_joints "['gripper']" \\
--operation.relative_exclude_joints "[]" \\
--push_to_hub true
Usage:
@@ -61,41 +66,58 @@ logger = logging.getLogger(__name__)
HF_DATASET_ID = ""
# Source for observation.state. Options:
# - A feature name (e.g. "observation.joints", "observation.pose") to copy
# an existing column. Must have the same shape as "action".
# - None to derive state from action with STATE_ACTION_OFFSET.
STATE_SOURCE_FEATURE: str | None = "observation.joints"
# Output repo ID. Set to None for default "<input>_modified".
OUTPUT_REPO_ID: str | None = None
# Only used when STATE_SOURCE_FEATURE is None.
# 0 → state[t] = action[t] (same instant)
# 1 → state[t] = action[t-1] (state lags by 1 step)
# Features to concatenate into observation.state. Order matters:
# action-matching features FIRST, then extra proprioception.
# Set to a single string to copy one feature directly.
STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"]
# Only used when STATE_SOURCE_FEATURES is None:
# derive state from action with a per-episode offset.
STATE_ACTION_OFFSET = 1
# Push the augmented dataset to the Hugging Face Hub.
PUSH_TO_HUB = True
def _build_state_from_feature(dataset: LeRobotDataset, source_feature: str) -> Callable:
"""Return a callable that copies values from an existing feature."""
def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]:
hf = dataset.hf_dataset
source_values = hf[source_feature]
ep = np.array(hf["episode_index"])
fr = np.array(hf["frame_index"])
return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))}
episode_indices = np.array(hf["episode_index"])
frame_indices = np.array(hf["frame_index"])
key_to_global = {(int(episode_indices[i]), int(frame_indices[i])): i for i in range(len(episode_indices))}
def _build_state_from_features(dataset: "LeRobotDataset", source_features: list[str]) -> "Callable":
    """Build a callback that concatenates ``source_features`` into observation.state.

    Order matters: features whose dimensions correspond to ``action`` must
    come first — the callback simply appends each feature's values in the
    order given (see the module docstring for why).

    Args:
        dataset: source dataset; only ``dataset.hf_dataset`` columns are read.
        source_features: feature names to concatenate, in output order.

    Returns:
        ``_get_state(row_dict, ep_idx, frame_idx)`` returning the flat
        concatenated state for that frame as a list. ``row_dict`` is accepted
        for interface compatibility but is not used.
    """
    hf = dataset.hf_dataset
    key_to_global = _build_global_index(dataset)
    # Hoist the column handles out of the per-frame callback; each call is
    # then a dict lookup plus appends.
    columns = [hf[feat] for feat in source_features]

    def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
        g = key_to_global[(ep_idx, frame_idx)]
        parts = []
        for col in columns:
            val = col[g]
            if hasattr(val, "tolist"):
                # Array-like value (numpy/torch): flatten via tolist(). A 0-d
                # value yields a scalar rather than a list, hence the check.
                flat = val.tolist()
                if isinstance(flat, list):
                    parts.extend(flat)
                else:
                    parts.append(flat)
            elif isinstance(val, list):
                parts.extend(val)
            else:
                # Plain scalar column entry — coerce to float for a uniform
                # numeric state vector (feature dtype is float32 downstream).
                parts.append(float(val))
        return parts

    return _get_state
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
"""Return a callable that derives state from action with a per-episode offset.
state[t] = action[max(0, t - offset)] (clamped to episode start)
"""
"""Derive state from action with a per-episode offset."""
hf = dataset.hf_dataset
episode_indices = np.array(hf["episode_index"])
frame_indices = np.array(hf["frame_index"])
@@ -134,19 +156,32 @@ def main():
action_meta = dataset.features["action"]
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
if STATE_SOURCE_FEATURE is not None:
if STATE_SOURCE_FEATURE not in dataset.features:
raise ValueError(
f"Source feature '{STATE_SOURCE_FEATURE}' not found. "
f"Available: {list(dataset.features.keys())}"
)
source_meta = dataset.features[STATE_SOURCE_FEATURE]
logger.info(f"Copying {STATE_SOURCE_FEATURE} → observation.state")
state_fn = _build_state_from_feature(dataset, STATE_SOURCE_FEATURE)
if STATE_SOURCE_FEATURES is not None:
source_list = (
[STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES)
)
for feat in source_list:
if feat not in dataset.features:
raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}")
# Compute combined shape and names
total_dim = 0
all_names = []
for feat in source_list:
meta = dataset.features[feat]
total_dim += meta["shape"][0]
names = meta.get("names")
if names:
all_names.extend(names)
logger.info(
f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})"
)
state_fn = _build_state_from_features(dataset, source_list)
state_feature_info = {
"dtype": "float32",
"shape": list(source_meta["shape"]),
"names": source_meta.get("names"),
"shape": [total_dim],
"names": all_names or None,
}
else:
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
@@ -160,6 +195,7 @@ def main():
augmented = add_features(
dataset,
features={"observation.state": (state_fn, state_feature_info)},
repo_id=OUTPUT_REPO_ID,
)
logger.info("observation.state added")
@@ -168,13 +204,14 @@ def main():
augmented.push_to_hub()
logger.info(
f"Done. Now recompute relative action stats:\n"
f"Done. Dataset at: {augmented.root}\n"
"Now recompute relative action stats:\n"
" lerobot-edit-dataset \\\n"
f" --repo_id {augmented.repo_id} \\\n"
" --operation.type recompute_stats \\\n"
" --operation.relative_action true \\\n"
" --operation.chunk_size 50 \\\n"
" --operation.relative_exclude_joints \"['gripper']\" \\\n"
' --operation.relative_exclude_joints "[]" \\\n'
" --push_to_hub true"
)