This commit is contained in:
Pepijn
2026-04-01 15:29:59 +02:00
parent dfe16e8b84
commit 0fc855df13
@@ -17,28 +17,33 @@
""" """
Add ``observation.state`` to an existing LeRobot dataset. Add ``observation.state`` to an existing LeRobot dataset.
pi0 with ``use_relative_actions=True`` requires ``observation.state`` to pi0 uses ``observation.state`` as its proprioceptive input AND for
compute relative actions (action - state) on the fly. This script adds relative action conversion (action - state). This script creates
that feature when it doesn't already exist. ``observation.state`` by concatenating one or more existing features.
Two modes for deriving ``observation.state``: Ordering matters: the features whose dimensions correspond to ``action``
must come FIRST, because ``RelativeActionsProcessorStep`` subtracts
``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE
pose) are appended after and are seen by the model but not used for
relative conversion.
1. **From an existing feature** (``STATE_SOURCE_FEATURE``): Example: action = [proximal, distal], and we want the model to also see
Copies an existing column (e.g. ``observation.joints`` or the EE pose:
``observation.pose``) to ``observation.state``.
2. **From action with offset** (``STATE_SOURCE_FEATURE = None``): STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"]
Derives state from the action column with a per-episode offset: → observation.state = [proximal, distal, x, y, z, ax, ay, az]
state[t] = action[t - STATE_ACTION_OFFSET]
After running this script, recompute relative action stats via the CLI: The relative conversion uses state[:2] = [proximal, distal] to subtract
from action[:2], and the model sees all 8 dimensions.
After running this script, recompute relative action stats:
lerobot-edit-dataset \\ lerobot-edit-dataset \\
--repo_id <your_dataset> \\ --repo_id <your_dataset> \\
--operation.type recompute_stats \\ --operation.type recompute_stats \\
--operation.relative_action true \\ --operation.relative_action true \\
--operation.chunk_size 50 \\ --operation.chunk_size 50 \\
--operation.relative_exclude_joints "['gripper']" \\ --operation.relative_exclude_joints "[]" \\
--push_to_hub true --push_to_hub true
Usage: Usage:
@@ -61,41 +66,58 @@ logger = logging.getLogger(__name__)
HF_DATASET_ID = "" HF_DATASET_ID = ""
# Source for observation.state. Options: # Output repo ID. Set to None for default "<input>_modified".
# - A feature name (e.g. "observation.joints", "observation.pose") to copy OUTPUT_REPO_ID: str | None = None
# an existing column. Must have the same shape as "action".
# - None to derive state from action with STATE_ACTION_OFFSET.
STATE_SOURCE_FEATURE: str | None = "observation.joints"
# Only used when STATE_SOURCE_FEATURE is None. # Features to concatenate into observation.state. Order matters:
# 0 → state[t] = action[t] (same instant) # action-matching features FIRST, then extra proprioception.
# 1 → state[t] = action[t-1] (state lags by 1 step) # Set to a single string to copy one feature directly.
STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"]
# Only used when STATE_SOURCE_FEATURES is None:
# derive state from action with a per-episode offset.
STATE_ACTION_OFFSET = 1 STATE_ACTION_OFFSET = 1
# Push the augmented dataset to the Hugging Face Hub. # Push the augmented dataset to the Hugging Face Hub.
PUSH_TO_HUB = True PUSH_TO_HUB = True
def _build_state_from_feature(dataset: LeRobotDataset, source_feature: str) -> Callable: def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]:
"""Return a callable that copies values from an existing feature."""
hf = dataset.hf_dataset hf = dataset.hf_dataset
source_values = hf[source_feature] ep = np.array(hf["episode_index"])
fr = np.array(hf["frame_index"])
return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))}
episode_indices = np.array(hf["episode_index"])
frame_indices = np.array(hf["frame_index"]) def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable:
key_to_global = {(int(episode_indices[i]), int(frame_indices[i])): i for i in range(len(episode_indices))} """Concatenate multiple features into observation.state."""
hf = dataset.hf_dataset
key_to_global = _build_global_index(dataset)
columns = [hf[feat] for feat in source_features]
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int): def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
return source_values[key_to_global[(ep_idx, frame_idx)]] g = key_to_global[(ep_idx, frame_idx)]
parts = []
for col in columns:
val = col[g]
if hasattr(val, "tolist"):
flat = val.tolist()
if isinstance(flat, list):
parts.extend(flat)
else:
parts.append(flat)
elif isinstance(val, list):
parts.extend(val)
else:
parts.append(float(val))
return parts
return _get_state return _get_state
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable: def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
"""Return a callable that derives state from action with a per-episode offset. """Derive state from action with a per-episode offset."""
state[t] = action[max(0, t - offset)] (clamped to episode start)
"""
hf = dataset.hf_dataset hf = dataset.hf_dataset
episode_indices = np.array(hf["episode_index"]) episode_indices = np.array(hf["episode_index"])
frame_indices = np.array(hf["frame_index"]) frame_indices = np.array(hf["frame_index"])
@@ -134,19 +156,32 @@ def main():
action_meta = dataset.features["action"] action_meta = dataset.features["action"]
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}") logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
if STATE_SOURCE_FEATURE is not None: if STATE_SOURCE_FEATURES is not None:
if STATE_SOURCE_FEATURE not in dataset.features: source_list = (
raise ValueError( [STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES)
f"Source feature '{STATE_SOURCE_FEATURE}' not found. " )
f"Available: {list(dataset.features.keys())}" for feat in source_list:
) if feat not in dataset.features:
source_meta = dataset.features[STATE_SOURCE_FEATURE] raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}")
logger.info(f"Copying {STATE_SOURCE_FEATURE} → observation.state")
state_fn = _build_state_from_feature(dataset, STATE_SOURCE_FEATURE) # Compute combined shape and names
total_dim = 0
all_names = []
for feat in source_list:
meta = dataset.features[feat]
total_dim += meta["shape"][0]
names = meta.get("names")
if names:
all_names.extend(names)
logger.info(
f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})"
)
state_fn = _build_state_from_features(dataset, source_list)
state_feature_info = { state_feature_info = {
"dtype": "float32", "dtype": "float32",
"shape": list(source_meta["shape"]), "shape": [total_dim],
"names": source_meta.get("names"), "names": all_names or None,
} }
else: else:
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}") logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
@@ -160,6 +195,7 @@ def main():
augmented = add_features( augmented = add_features(
dataset, dataset,
features={"observation.state": (state_fn, state_feature_info)}, features={"observation.state": (state_fn, state_feature_info)},
repo_id=OUTPUT_REPO_ID,
) )
logger.info("observation.state added") logger.info("observation.state added")
@@ -168,13 +204,14 @@ def main():
augmented.push_to_hub() augmented.push_to_hub()
logger.info( logger.info(
f"Done. Now recompute relative action stats:\n" f"Done. Dataset at: {augmented.root}\n"
"Now recompute relative action stats:\n"
" lerobot-edit-dataset \\\n" " lerobot-edit-dataset \\\n"
f" --repo_id {augmented.repo_id} \\\n" f" --repo_id {augmented.repo_id} \\\n"
" --operation.type recompute_stats \\\n" " --operation.type recompute_stats \\\n"
" --operation.relative_action true \\\n" " --operation.relative_action true \\\n"
" --operation.chunk_size 50 \\\n" " --operation.chunk_size 50 \\\n"
" --operation.relative_exclude_joints \"['gripper']\" \\\n" ' --operation.relative_exclude_joints "[]" \\\n'
" --push_to_hub true" " --push_to_hub true"
) )