mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix
This commit is contained in:
@@ -17,28 +17,33 @@
|
|||||||
"""
|
"""
|
||||||
Add ``observation.state`` to an existing LeRobot dataset.
|
Add ``observation.state`` to an existing LeRobot dataset.
|
||||||
|
|
||||||
pi0 with ``use_relative_actions=True`` requires ``observation.state`` to
|
pi0 uses ``observation.state`` as its proprioceptive input AND for
|
||||||
compute relative actions (action − state) on the fly. This script adds
|
relative action conversion (action − state). This script creates
|
||||||
that feature when it doesn't already exist.
|
``observation.state`` by concatenating one or more existing features.
|
||||||
|
|
||||||
Two modes for deriving ``observation.state``:
|
Ordering matters: the features whose dimensions correspond to ``action``
|
||||||
|
must come FIRST, because ``RelativeActionsProcessorStep`` subtracts
|
||||||
|
``state[:action_dim]`` from the action. Extra state dimensions (e.g. EE
|
||||||
|
pose) are appended after and are seen by the model but not used for
|
||||||
|
relative conversion.
|
||||||
|
|
||||||
1. **From an existing feature** (``STATE_SOURCE_FEATURE``):
|
Example: action = [proximal, distal], and we want the model to also see
|
||||||
Copies an existing column (e.g. ``observation.joints`` or
|
the EE pose:
|
||||||
``observation.pose``) to ``observation.state``.
|
|
||||||
|
|
||||||
2. **From action with offset** (``STATE_SOURCE_FEATURE = None``):
|
STATE_SOURCE_FEATURES = ["observation.joints", "observation.pose"]
|
||||||
Derives state from the action column with a per-episode offset:
|
→ observation.state = [proximal, distal, x, y, z, ax, ay, az]
|
||||||
state[t] = action[t - STATE_ACTION_OFFSET]
|
|
||||||
|
|
||||||
After running this script, recompute relative action stats via the CLI:
|
The relative conversion subtracts state[:2] = [proximal, distal]
|
||||||
|
from action[:2], and the model sees all 8 dimensions.
|
||||||
|
|
||||||
|
After running this script, recompute relative action stats:
|
||||||
|
|
||||||
lerobot-edit-dataset \\
|
lerobot-edit-dataset \\
|
||||||
--repo_id <your_dataset> \\
|
--repo_id <your_dataset> \\
|
||||||
--operation.type recompute_stats \\
|
--operation.type recompute_stats \\
|
||||||
--operation.relative_action true \\
|
--operation.relative_action true \\
|
||||||
--operation.chunk_size 50 \\
|
--operation.chunk_size 50 \\
|
||||||
--operation.relative_exclude_joints "['gripper']" \\
|
--operation.relative_exclude_joints "[]" \\
|
||||||
--push_to_hub true
|
--push_to_hub true
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@@ -61,41 +66,58 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
HF_DATASET_ID = ""
|
HF_DATASET_ID = ""
|
||||||
|
|
||||||
# Source for observation.state. Options:
|
# Output repo ID. Set to None for default "<input>_modified".
|
||||||
# - A feature name (e.g. "observation.joints", "observation.pose") to copy
|
OUTPUT_REPO_ID: str | None = None
|
||||||
# an existing column. Must have the same shape as "action".
|
|
||||||
# - None to derive state from action with STATE_ACTION_OFFSET.
|
|
||||||
STATE_SOURCE_FEATURE: str | None = "observation.joints"
|
|
||||||
|
|
||||||
# Only used when STATE_SOURCE_FEATURE is None.
|
# Features to concatenate into observation.state. Order matters:
|
||||||
# 0 → state[t] = action[t] (same instant)
|
# action-matching features FIRST, then extra proprioception.
|
||||||
# 1 → state[t] = action[t-1] (state lags by 1 step)
|
# Set to a single string to copy one feature directly, or to None to derive state from action.
|
||||||
|
STATE_SOURCE_FEATURES: list[str] | str = ["observation.joints", "observation.pose"]
|
||||||
|
|
||||||
|
# Only used when STATE_SOURCE_FEATURES is None:
|
||||||
|
# derive state from action with a per-episode offset.
|
||||||
STATE_ACTION_OFFSET = 1
|
STATE_ACTION_OFFSET = 1
|
||||||
|
|
||||||
# Push the augmented dataset to the Hugging Face Hub.
|
# Push the augmented dataset to the Hugging Face Hub.
|
||||||
PUSH_TO_HUB = True
|
PUSH_TO_HUB = True
|
||||||
|
|
||||||
|
|
||||||
def _build_state_from_feature(dataset: LeRobotDataset, source_feature: str) -> Callable:
|
def _build_global_index(dataset: LeRobotDataset) -> dict[tuple[int, int], int]:
|
||||||
"""Return a callable that copies values from an existing feature."""
|
|
||||||
hf = dataset.hf_dataset
|
hf = dataset.hf_dataset
|
||||||
source_values = hf[source_feature]
|
ep = np.array(hf["episode_index"])
|
||||||
|
fr = np.array(hf["frame_index"])
|
||||||
|
return {(int(ep[i]), int(fr[i])): i for i in range(len(ep))}
|
||||||
|
|
||||||
episode_indices = np.array(hf["episode_index"])
|
|
||||||
frame_indices = np.array(hf["frame_index"])
|
def _build_state_from_features(dataset: LeRobotDataset, source_features: list[str]) -> Callable:
|
||||||
key_to_global = {(int(episode_indices[i]), int(frame_indices[i])): i for i in range(len(episode_indices))}
|
"""Concatenate multiple features into observation.state."""
|
||||||
|
hf = dataset.hf_dataset
|
||||||
|
key_to_global = _build_global_index(dataset)
|
||||||
|
|
||||||
|
columns = [hf[feat] for feat in source_features]
|
||||||
|
|
||||||
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
|
def _get_state(row_dict: dict, ep_idx: int, frame_idx: int):
|
||||||
return source_values[key_to_global[(ep_idx, frame_idx)]]
|
g = key_to_global[(ep_idx, frame_idx)]
|
||||||
|
parts = []
|
||||||
|
for col in columns:
|
||||||
|
val = col[g]
|
||||||
|
if hasattr(val, "tolist"):
|
||||||
|
flat = val.tolist()
|
||||||
|
if isinstance(flat, list):
|
||||||
|
parts.extend(flat)
|
||||||
|
else:
|
||||||
|
parts.append(flat)
|
||||||
|
elif isinstance(val, list):
|
||||||
|
parts.extend(val)
|
||||||
|
else:
|
||||||
|
parts.append(float(val))
|
||||||
|
return parts
|
||||||
|
|
||||||
return _get_state
|
return _get_state
|
||||||
|
|
||||||
|
|
||||||
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
|
def _build_state_from_action_offset(dataset: LeRobotDataset, offset: int) -> Callable:
|
||||||
"""Return a callable that derives state from action with a per-episode offset.
|
"""Derive state from action with a per-episode offset."""
|
||||||
|
|
||||||
state[t] = action[max(0, t - offset)] (clamped to episode start)
|
|
||||||
"""
|
|
||||||
hf = dataset.hf_dataset
|
hf = dataset.hf_dataset
|
||||||
episode_indices = np.array(hf["episode_index"])
|
episode_indices = np.array(hf["episode_index"])
|
||||||
frame_indices = np.array(hf["frame_index"])
|
frame_indices = np.array(hf["frame_index"])
|
||||||
@@ -134,19 +156,32 @@ def main():
|
|||||||
action_meta = dataset.features["action"]
|
action_meta = dataset.features["action"]
|
||||||
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
|
logger.info(f"Action shape: {action_meta['shape']}, names: {action_meta.get('names')}")
|
||||||
|
|
||||||
if STATE_SOURCE_FEATURE is not None:
|
if STATE_SOURCE_FEATURES is not None:
|
||||||
if STATE_SOURCE_FEATURE not in dataset.features:
|
source_list = (
|
||||||
raise ValueError(
|
[STATE_SOURCE_FEATURES] if isinstance(STATE_SOURCE_FEATURES, str) else list(STATE_SOURCE_FEATURES)
|
||||||
f"Source feature '{STATE_SOURCE_FEATURE}' not found. "
|
)
|
||||||
f"Available: {list(dataset.features.keys())}"
|
for feat in source_list:
|
||||||
)
|
if feat not in dataset.features:
|
||||||
source_meta = dataset.features[STATE_SOURCE_FEATURE]
|
raise ValueError(f"Feature '{feat}' not found. Available: {list(dataset.features.keys())}")
|
||||||
logger.info(f"Copying {STATE_SOURCE_FEATURE} → observation.state")
|
|
||||||
state_fn = _build_state_from_feature(dataset, STATE_SOURCE_FEATURE)
|
# Compute combined shape and names
|
||||||
|
total_dim = 0
|
||||||
|
all_names = []
|
||||||
|
for feat in source_list:
|
||||||
|
meta = dataset.features[feat]
|
||||||
|
total_dim += meta["shape"][0]
|
||||||
|
names = meta.get("names")
|
||||||
|
if names:
|
||||||
|
all_names.extend(names)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Concatenating {source_list} → observation.state (shape=[{total_dim}], names={all_names})"
|
||||||
|
)
|
||||||
|
state_fn = _build_state_from_features(dataset, source_list)
|
||||||
state_feature_info = {
|
state_feature_info = {
|
||||||
"dtype": "float32",
|
"dtype": "float32",
|
||||||
"shape": list(source_meta["shape"]),
|
"shape": [total_dim],
|
||||||
"names": source_meta.get("names"),
|
"names": all_names or None,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
|
logger.info(f"Deriving observation.state from action with offset={STATE_ACTION_OFFSET}")
|
||||||
@@ -160,6 +195,7 @@ def main():
|
|||||||
augmented = add_features(
|
augmented = add_features(
|
||||||
dataset,
|
dataset,
|
||||||
features={"observation.state": (state_fn, state_feature_info)},
|
features={"observation.state": (state_fn, state_feature_info)},
|
||||||
|
repo_id=OUTPUT_REPO_ID,
|
||||||
)
|
)
|
||||||
logger.info("observation.state added")
|
logger.info("observation.state added")
|
||||||
|
|
||||||
@@ -168,13 +204,14 @@ def main():
|
|||||||
augmented.push_to_hub()
|
augmented.push_to_hub()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Done. Now recompute relative action stats:\n"
|
f"Done. Dataset at: {augmented.root}\n"
|
||||||
|
"Now recompute relative action stats:\n"
|
||||||
" lerobot-edit-dataset \\\n"
|
" lerobot-edit-dataset \\\n"
|
||||||
f" --repo_id {augmented.repo_id} \\\n"
|
f" --repo_id {augmented.repo_id} \\\n"
|
||||||
" --operation.type recompute_stats \\\n"
|
" --operation.type recompute_stats \\\n"
|
||||||
" --operation.relative_action true \\\n"
|
" --operation.relative_action true \\\n"
|
||||||
" --operation.chunk_size 50 \\\n"
|
" --operation.chunk_size 50 \\\n"
|
||||||
" --operation.relative_exclude_joints \"['gripper']\" \\\n"
|
' --operation.relative_exclude_joints "[]" \\\n'
|
||||||
" --push_to_hub true"
|
" --push_to_hub true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user