mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 07:09:43 +00:00
Compare commits
40 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6bbc24a991 | |||
| 0394fae446 | |||
| 602b8e66a6 | |||
| ab4dce6fed | |||
| 40f4386e4a | |||
| 87a91b4b08 | |||
| fadb900c36 | |||
| de0663226a | |||
| 0ca9d66cae | |||
| 2222f25da3 | |||
| acae8417aa | |||
| 2697f65cf6 | |||
| 74f42f218e | |||
| ca9d49e305 | |||
| 6705876d47 | |||
| aadbd27675 | |||
| 5221647b5e | |||
| 9c981300dd | |||
| f5b27aad1b | |||
| 75f1285507 | |||
| 33cedc2f71 | |||
| aa32e6c4ab | |||
| f906270ec4 | |||
| 733b6d84db | |||
| 8abc9037a3 | |||
| e4d4ac0bda | |||
| e79b2a439b | |||
| f9ae78ca74 | |||
| e1ced538e3 | |||
| 2a98602ad6 | |||
| a2f5b3571e | |||
| cecf2eff4f | |||
| 7e6b598a51 | |||
| 4fa41ba806 | |||
| 1de2b87a92 | |||
| e3c511db67 | |||
| aed4130d39 | |||
| d26349c692 | |||
| a9bce4732b | |||
| 86d69e3c1d |
@@ -25,6 +25,7 @@ This module provides utilities for:
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
@@ -37,7 +38,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats, get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
@@ -45,6 +46,8 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
flatten_dict,
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
update_chunk_file_indices,
|
||||
@@ -141,6 +144,315 @@ def delete_episodes(
|
||||
return new_dataset
|
||||
|
||||
|
||||
def trim_episode_start(
    dataset: LeRobotDataset,
    seconds: float,
    episode_indices: list[int] | None = None,
    output_dir: str | Path | None = None,
    repo_id: str | None = None,
) -> LeRobotDataset:
    """Trim the first N seconds from selected episodes and create a new dataset.

    The operation rewrites data parquet files and updates episode metadata so that:
    - frame_index starts at 0 for each trimmed episode
    - timestamp starts at 0 for each trimmed episode
    - global index remains contiguous across the full dataset
    - dataset_from_index / dataset_to_index reflect new frame ranges

    Video files are copied as-is and per-episode video timestamps are shifted forward
    for trimmed episodes.

    Episodes selected for trimming that are too short (length <= trim_frames) are skipped
    from the output dataset.

    Args:
        dataset: The source LeRobotDataset.
        seconds: Number of seconds to remove from episode starts.
        episode_indices: Optional list of episode indices to trim. If None, trims all episodes.
        output_dir: Directory to save the new dataset. If None, uses default location.
        repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.

    Returns:
        A new LeRobotDataset loaded from the written output directory.

    Raises:
        ValueError: If ``seconds`` is non-positive or rounds to zero frames, if no
            episodes are specified / all are skipped, if episode indices are invalid,
            if no parquet files are found, or if an episode spans multiple data files.
    """
    # --- Validate inputs and resolve the trim amount in frames ---
    if seconds <= 0:
        raise ValueError(f"seconds must be strictly positive, got {seconds}")

    # Episode metadata may be lazily loaded; make sure it is available.
    if dataset.meta.episodes is None:
        dataset.meta.episodes = load_episodes(dataset.meta.root)

    trim_frames = int(seconds * dataset.meta.fps)
    if trim_frames <= 0:
        raise ValueError(
            f"seconds={seconds} corresponds to 0 frames at fps={dataset.meta.fps}. "
            "Increase seconds so at least one frame is trimmed."
        )

    if episode_indices is None:
        episode_indices = list(range(dataset.meta.total_episodes))

    if len(episode_indices) == 0:
        raise ValueError("No episodes specified to trim")

    episode_indices = sorted(set(episode_indices))
    valid_indices = set(range(dataset.meta.total_episodes))
    invalid = set(episode_indices) - valid_indices
    if invalid:
        raise ValueError(f"Invalid episode indices: {invalid}")

    # Episodes with length <= trim_frames would end up empty: drop them entirely.
    too_short = sorted(
        ep_idx for ep_idx in episode_indices if int(dataset.meta.episodes[ep_idx]["length"]) <= trim_frames
    )
    trim_set = set(episode_indices)
    skipped_set = set(too_short)
    trim_set -= skipped_set

    if too_short:
        logging.warning(
            f"Skipping {len(too_short)} episode(s) that are too short to trim "
            f"({trim_frames} frames): {too_short}"
        )

    episodes_to_keep = [ep_idx for ep_idx in range(dataset.meta.total_episodes) if ep_idx not in skipped_set]
    if not episodes_to_keep:
        raise ValueError(
            "All episodes selected for trimming are too short and would be skipped. "
            "Try a smaller trim duration."
        )

    logging.info(
        f"Trimming {len(trim_set)} episode(s) by {seconds}s and keeping {len(episodes_to_keep)} "
        f"episode(s) in output"
    )

    # --- Create the output dataset metadata shell ---
    if repo_id is None:
        repo_id = f"{dataset.repo_id}_trimmed"
    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id

    new_meta = LeRobotDatasetMetadata.create(
        repo_id=repo_id,
        fps=dataset.meta.fps,
        features=dataset.meta.features,
        robot_type=dataset.meta.robot_type,
        root=output_dir,
        use_videos=len(dataset.meta.video_keys) > 0,
        chunks_size=dataset.meta.chunks_size,
        data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
        video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
    )

    # Tasks are unchanged by trimming; copy them over verbatim.
    if dataset.meta.tasks is not None:
        write_tasks(dataset.meta.tasks, new_meta.root)
        new_meta.tasks = dataset.meta.tasks.copy()

    # Copy optional subtask annotations if the source dataset has them.
    subtasks_path = dataset.root / DEFAULT_SUBTASKS_PATH
    if subtasks_path.exists():
        dst_subtasks_path = new_meta.root / DEFAULT_SUBTASKS_PATH
        dst_subtasks_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(subtasks_path, dst_subtasks_path)

    # --- Precompute the old->new episode index mapping and new frame ranges ---
    episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)}
    trim_duration_s = trim_frames / dataset.meta.fps

    episode_lengths: dict[int, int] = {}
    episode_ranges: dict[int, tuple[int, int]] = {}
    total_frames = 0
    for old_ep_idx in episodes_to_keep:
        new_ep_idx = episode_mapping[old_ep_idx]
        src_length = int(dataset.meta.episodes[old_ep_idx]["length"])
        new_length = src_length - trim_frames if old_ep_idx in trim_set else src_length
        episode_lengths[new_ep_idx] = new_length
        # Contiguous global [from, to) range for the new dataset.
        episode_ranges[new_ep_idx] = (total_frames, total_frames + new_length)
        total_frames += new_length

    # Per-episode stats are recomputed only for numeric (non-image/video/string) features.
    numeric_features = {
        k: v
        for k, v in dataset.meta.features.items()
        if v["dtype"] not in ["image", "video", "string"]
    }
    episode_stats_parts: dict[int, list[dict[str, dict]]] = defaultdict(list)
    episode_file_metadata: dict[int, dict[str, int]] = {}

    data_dir = dataset.root / DATA_DIR
    parquet_files = sorted(data_dir.glob("*/*.parquet"))
    if not parquet_files:
        raise ValueError(f"No parquet files found in {data_dir}")

    # --- Rewrite each data file: drop skipped episodes, trim frames, renumber ---
    for src_path in tqdm(parquet_files, desc="Trimming data files"):
        df = pd.read_parquet(src_path).reset_index(drop=True)

        if len(df) == 0:
            continue

        # Remove rows belonging to episodes skipped entirely.
        if skipped_set:
            keep_mask = ~df["episode_index"].isin(skipped_set)
            if not keep_mask.all():
                df = df.loc[keep_mask].copy().reset_index(drop=True)

        if len(df) == 0:
            continue

        # Remove the first trim_frames rows of each trimmed episode.
        if trim_set:
            trim_mask = df["episode_index"].isin(trim_set) & (df["frame_index"] < trim_frames)
            if trim_mask.any():
                df = df.loc[~trim_mask].copy().reset_index(drop=True)

        if len(df) == 0:
            continue

        # Recover chunk/file indices from the "chunk-XXX/file-YYY.parquet" path layout.
        relative_path = src_path.relative_to(dataset.root)
        chunk_idx = int(relative_path.parts[1].split("-")[1])
        file_idx = int(relative_path.parts[2].split("-")[1].split(".")[0])

        # Ascending iteration is safe: renumbered indices are always <= their old
        # value, so already-renamed rows can never collide with a later ep_mask.
        for old_ep_idx in sorted(df["episode_index"].unique().tolist()):
            ep_mask = df["episode_index"] == old_ep_idx
            new_ep_idx = episode_mapping[old_ep_idx]

            if old_ep_idx in trim_set:
                df.loc[ep_mask, "frame_index"] = df.loc[ep_mask, "frame_index"] - trim_frames
                shifted_timestamps = df.loc[ep_mask, "timestamp"].to_numpy(dtype=np.float64) - trim_duration_s
                # Clip guards against tiny negative values from float rounding.
                df.loc[ep_mask, "timestamp"] = np.clip(shifted_timestamps, a_min=0.0, a_max=None)

            df.loc[ep_mask, "episode_index"] = new_ep_idx

            # Rebuild the contiguous global index from the episode's new start offset.
            ep_start, _ = episode_ranges[new_ep_idx]
            new_indices = ep_start + df.loc[ep_mask, "frame_index"].to_numpy(dtype=np.int64)
            df.loc[ep_mask, "index"] = new_indices

            if new_ep_idx in episode_file_metadata:
                existing = episode_file_metadata[new_ep_idx]
                if (
                    existing["data/chunk_index"] != chunk_idx
                    or existing["data/file_index"] != file_idx
                ):
                    raise ValueError(
                        f"Episode {old_ep_idx} spans multiple data files. "
                        "trim_episode_start currently expects one data file per episode."
                    )
            else:
                episode_file_metadata[new_ep_idx] = {
                    "data/chunk_index": chunk_idx,
                    "data/file_index": file_idx,
                }

            # Recompute per-episode stats over the remaining (post-trim) rows.
            if numeric_features:
                ep_df = df.loc[ep_mask]
                episode_data: dict[str, np.ndarray] = {}
                episode_feature_spec: dict[str, dict] = {}

                for key, feature in numeric_features.items():
                    if key not in ep_df.columns:
                        continue

                    values = ep_df[key].to_numpy()
                    if len(values) == 0:
                        continue

                    # Vector-valued columns are stored as object arrays of
                    # ndarrays/lists and must be stacked into a 2D array.
                    first_value = values[0]
                    if isinstance(first_value, np.ndarray):
                        episode_data[key] = np.stack(values)
                    elif isinstance(first_value, (list, tuple)):
                        episode_data[key] = np.stack(values)
                    else:
                        episode_data[key] = np.asarray(values)

                    episode_feature_spec[key] = feature

                if episode_data:
                    episode_stats_parts[new_ep_idx].append(
                        compute_episode_stats(episode_data, episode_feature_spec)
                    )

        # Pandas may have upcast these to float during .loc assignment; restore ints.
        df["index"] = df["index"].astype(np.int64)
        if "frame_index" in df.columns:
            df["frame_index"] = df["frame_index"].astype(np.int64)

        dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        _write_parquet(df, dst_path, new_meta)

    # --- Write per-episode metadata rows for the new dataset ---
    all_episode_stats = []
    for old_ep_idx in tqdm(episodes_to_keep, desc="Writing episode metadata"):
        new_ep_idx = episode_mapping[old_ep_idx]

        if new_ep_idx not in episode_file_metadata:
            raise ValueError(f"Missing data file metadata for episode {old_ep_idx}")

        from_idx, to_idx = episode_ranges[new_ep_idx]
        src_episode = dataset.meta.episodes[old_ep_idx]
        ep_data_meta = episode_file_metadata[new_ep_idx]

        # An episode confined to one file has exactly one stats part; aggregate
        # defensively if more than one was collected.
        stats_parts = episode_stats_parts.get(new_ep_idx, [])
        ep_stats = aggregate_stats(stats_parts) if len(stats_parts) > 1 else (stats_parts[0] if stats_parts else {})
        if ep_stats:
            all_episode_stats.append(ep_stats)

        episode_meta = {
            "data/chunk_index": ep_data_meta["data/chunk_index"],
            "data/file_index": ep_data_meta["data/file_index"],
            "dataset_from_index": from_idx,
            "dataset_to_index": to_idx,
        }

        # Videos are copied untouched, so a trimmed episode simply starts later
        # inside the same video file: shift from_timestamp forward.
        for video_key in dataset.meta.video_keys:
            from_ts = src_episode[f"videos/{video_key}/from_timestamp"]
            if old_ep_idx in trim_set:
                from_ts += trim_duration_s
            episode_meta.update(
                {
                    f"videos/{video_key}/chunk_index": src_episode[f"videos/{video_key}/chunk_index"],
                    f"videos/{video_key}/file_index": src_episode[f"videos/{video_key}/file_index"],
                    f"videos/{video_key}/from_timestamp": from_ts,
                    f"videos/{video_key}/to_timestamp": src_episode[f"videos/{video_key}/to_timestamp"],
                }
            )

        episode_dict = {
            "episode_index": new_ep_idx,
            "tasks": src_episode["tasks"],
            "length": episode_lengths[new_ep_idx],
        }
        episode_dict.update(episode_meta)
        if ep_stats:
            episode_dict.update(flatten_dict({"stats": ep_stats}))

        new_meta._save_episode_metadata(episode_dict)

    new_meta._close_writer()

    if new_meta.video_keys:
        _copy_videos(dataset, new_meta)

    # --- Finalize info.json, stats and return the reloaded dataset ---
    new_meta.info.update(
        {
            "total_episodes": len(episodes_to_keep),
            "total_frames": total_frames,
            "total_tasks": len(new_meta.tasks) if new_meta.tasks is not None else 0,
            "splits": {"train": f"0:{len(episodes_to_keep)}"},
        }
    )

    # Preserve per-video codec/resolution info from the source features.
    if new_meta.video_keys and dataset.meta.video_keys:
        for key in new_meta.video_keys:
            if key in dataset.meta.features:
                new_meta.info["features"][key]["info"] = dataset.meta.info["features"][key].get("info", {})

    write_info(new_meta.info, new_meta.root)

    # Merge recomputed numeric stats with source stats for untouched features
    # (e.g. image/video stats that were not recomputed here).
    merged_stats = aggregate_stats(all_episode_stats) if all_episode_stats else {}
    if dataset.meta.stats:
        for key, value in dataset.meta.stats.items():
            if key not in merged_stats:
                merged_stats[key] = value
    if merged_stats:
        write_stats(merged_stats, new_meta.root)

    return LeRobotDataset(
        repo_id=repo_id,
        root=output_dir,
        image_transforms=dataset.image_transforms,
        delta_timestamps=dataset.delta_timestamps,
        tolerance_s=dataset.tolerance_s,
    )
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
splits: dict[str, float | list[int]],
|
||||
@@ -1522,6 +1834,122 @@ def modify_tasks(
|
||||
return dataset
|
||||
|
||||
|
||||
def recompute_stats(
    dataset: LeRobotDataset,
    skip_image_video: bool = True,
    delta_action: bool = False,
    delta_exclude_joints: list[str] | None = None,
) -> LeRobotDataset:
    """Recompute stats.json from scratch by iterating all episodes.

    Args:
        dataset: The LeRobotDataset to recompute stats for.
        skip_image_video: If True (default), only recompute stats for numeric features
            (action, state, etc.) and keep existing image/video stats unchanged.
        delta_action: If True, compute action stats as delta (action - state).
            Useful when training with use_delta_actions=True so normalization matches.
        delta_exclude_joints: Joint names to exclude from delta conversion when
            delta_action=True. These dims keep absolute stats. Uses dataset's
            action feature names to build the mask. Default: ["gripper"].

    Returns:
        The same dataset with updated stats.

    Raises:
        ValueError: If no parquet files are found, or if delta conversion is
            requested but an episode is missing the "observation.state" column.
    """
    # Hoisted out of the per-episode loop (was re-imported on every iteration).
    from lerobot.processor.delta_action_processor import to_delta_actions

    features = dataset.meta.features
    # Bookkeeping columns never get stats.
    bookkeeping_keys = ["index", "episode_index", "task_index", "frame_index", "timestamp"]
    numeric_features = {
        k: v for k, v in features.items()
        if v["dtype"] not in ["image", "video", "string"] and k not in bookkeeping_keys
    }

    if skip_image_video:
        features_to_compute = numeric_features
    else:
        features_to_compute = {
            k: v for k, v in features.items()
            if v["dtype"] != "string" and k not in bookkeeping_keys
        }

    # Build delta mask if delta_action is enabled
    delta_mask = None
    if delta_action and "action" in features and "observation.state" in features:
        if delta_exclude_joints is None:
            delta_exclude_joints = ["gripper"]
        action_names = features["action"].get("names")
        if action_names is not None:
            exclude = set(delta_exclude_joints)
            delta_mask = [n not in exclude for n in action_names]
        else:
            # No names available: convert every action dimension.
            action_dim = features["action"]["shape"][0]
            delta_mask = [True] * action_dim
        # Only recompute action stats when delta is enabled — state stays unchanged
        features_to_compute = {"action": features["action"]}
        logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
    else:
        logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")

    data_dir = dataset.root / DATA_DIR
    parquet_files = sorted(data_dir.glob("*/*.parquet"))
    if not parquet_files:
        raise ValueError(f"No parquet files found in {data_dir}")

    all_episode_stats = []
    numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]

    def _column_to_array(series) -> np.ndarray:
        """Stack an object column of per-row vectors into a 2D array, or pass scalars through."""
        values = series.values
        if len(values) and hasattr(values[0], "__len__"):
            return np.stack(values)
        return np.array(values)

    for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
        df = pd.read_parquet(parquet_path)

        for ep_idx in sorted(df["episode_index"].unique()):
            ep_df = df[df["episode_index"] == ep_idx]
            episode_data = {
                key: _column_to_array(ep_df[key]) for key in numeric_keys if key in ep_df.columns
            }

            # Apply delta conversion to actions before computing stats
            if delta_mask is not None and "action" in episode_data:
                if "observation.state" not in ep_df.columns:
                    # BUGFIX: previously `states` was left undefined here (NameError),
                    # or silently reused the previous episode's state — fail loudly.
                    raise ValueError(
                        f"delta_action=True but 'observation.state' column is missing in {parquet_path}"
                    )
                states = _column_to_array(ep_df["observation.state"])
                actions_t = torch.from_numpy(episode_data["action"]).float()
                states_t = torch.from_numpy(states).float()
                episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()

            ep_stats = compute_episode_stats(episode_data, features_to_compute)
            all_episode_stats.append(ep_stats)

    if not all_episode_stats:
        logging.warning("No episode stats computed")
        return dataset

    new_stats = aggregate_stats(all_episode_stats)

    # Merge: keep existing stats for features we didn't recompute
    if dataset.meta.stats:
        for key, value in dataset.meta.stats.items():
            if key not in new_stats:
                new_stats[key] = value

    write_stats(new_stats, dataset.root)
    dataset.meta.stats = new_stats

    logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
    return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
|
||||
@@ -470,6 +470,13 @@ def make_policy(
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
|
||||
# Store action feature names for delta_exclude_joints support
|
||||
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
|
||||
@@ -50,6 +50,13 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -21,8 +21,10 @@ import torch
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -126,7 +128,13 @@ def make_pi0_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
delta_step = DeltaActionsProcessorStep(
|
||||
enabled=config.use_delta_actions,
|
||||
exclude_joints=getattr(config, "delta_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -138,6 +146,7 @@ def make_pi0_pre_post_processors(
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
delta_step,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
@@ -149,6 +158,7 @@ def make_pi0_pre_post_processors(
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
@@ -50,6 +50,13 @@ class PI05Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime from dataset metadata by make_policy.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -25,7 +25,9 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -129,10 +131,19 @@ def make_pi05_pre_post_processors(
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
delta_step = DeltaActionsProcessorStep(
|
||||
enabled=config.use_delta_actions,
|
||||
exclude_joints=getattr(config, "delta_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
delta_step,
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
@@ -154,6 +165,7 @@ def make_pi05_pre_post_processors(
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
@@ -41,6 +41,9 @@ class PI0FastConfig(PreTrainedConfig):
|
||||
max_action_dim: int = 32
|
||||
max_action_tokens: int = 256
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -48,12 +48,14 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.processor.delta_action_processor import to_absolute_actions
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -1315,6 +1317,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
action_tokens, action_horizon=action_horizon, action_dim=action_dim
|
||||
)
|
||||
|
||||
if self.config.use_delta_actions and OBS_STATE in batch:
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
continuous_actions = to_absolute_actions(
|
||||
continuous_actions, state, [True] * continuous_actions.shape[-1]
|
||||
)
|
||||
|
||||
return continuous_actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
|
||||
@@ -27,6 +27,7 @@ from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -147,6 +148,7 @@ def make_pi0_fast_pre_post_processors(
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
|
||||
@@ -28,7 +28,14 @@ from .core import (
|
||||
RobotObservation,
|
||||
TransitionKey,
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .delta_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
MapTensorToDeltaActionDictStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .factory import (
|
||||
make_default_processors,
|
||||
@@ -97,6 +104,8 @@ __all__ = [
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"DeltaActionsProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
@@ -126,6 +135,8 @@ __all__ = [
|
||||
"transition_to_batch",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessorStep",
|
||||
"to_absolute_actions",
|
||||
"to_delta_actions",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -14,12 +14,54 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .core import PolicyAction, RobotAction
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
|
||||
|
||||
def to_delta_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert absolute actions to delta: delta = action - state (for masked dims).

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.

    Returns:
        A new tensor; the input ``actions`` is left untouched.
    """
    convert = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
    n_dims = convert.shape[0]
    # Zero out the offset for dims that stay absolute.
    offset = convert * state[..., :n_dims]
    if actions.ndim == 3:
        # Insert a time axis so one per-batch state broadcasts over all steps.
        offset = offset.unsqueeze(-2)
    result = actions.clone()
    result[..., :n_dims] = result[..., :n_dims] - offset
    return result
|
||||
|
||||
|
||||
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert delta actions back to absolute: absolute = delta + state (for masked dims).

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.

    Returns:
        A new tensor; the input ``actions`` is left untouched.
    """
    convert = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
    n_dims = convert.shape[0]
    # Zero out the offset for dims that stay absolute.
    offset = convert * state[..., :n_dims]
    if actions.ndim == 3:
        # Insert a time axis so one per-batch state broadcasts over all steps.
        offset = offset.unsqueeze(-2)
    result = actions.clone()
    result[..., :n_dims] = result[..., :n_dims] + offset
    return result
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||
@@ -141,3 +183,126 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class DeltaActionsProcessorStep(ProcessorStep):
    """Converts absolute actions to delta actions (action -= state) for masked dimensions.

    Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
    trains on relative offsets instead of absolute positions.
    Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
    the conversion during postprocessing.

    Attributes:
        enabled: Whether to apply the delta conversion.
        exclude_joints: Joint names to keep absolute (not converted to delta).
        action_names: Action dimension names from dataset metadata, used to build
            the mask from exclude_joints. If None, all dims are converted.
    """

    enabled: bool = False
    exclude_joints: list[str] = field(default_factory=list)
    action_names: list[str] | None = None
    # Last observed state; cached even when disabled so a paired
    # AbsoluteActionsProcessorStep can undo the conversion later.
    _last_state: torch.Tensor | None = field(default=None, init=False, repr=False)

    def _build_mask(self, action_dim: int) -> list[bool]:
        """Return a per-dimension mask: True = convert to delta, False = keep absolute."""
        # Without names we cannot match exclusions: convert everything.
        if not self.exclude_joints or self.action_names is None:
            return [True] * action_dim

        exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
        if not exclude_tokens:
            return [True] * action_dim

        mask = []
        for name in self.action_names[:action_dim]:
            action_name = str(name).lower()
            # Substring match so e.g. "gripper" also excludes "left_gripper".
            is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
            mask.append(not is_excluded)

        # If metadata names fewer dims than the action tensor has, convert the rest.
        if len(mask) < action_dim:
            mask.extend([True] * (action_dim - len(mask)))

        return mask

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Apply the delta conversion to the transition's action (when enabled)."""
        observation = transition.get(TransitionKey.OBSERVATION, {})
        state = observation.get(OBS_STATE) if observation else None

        # Always cache state for the paired AbsoluteActionsProcessorStep
        if state is not None:
            self._last_state = state

        if not self.enabled:
            return transition

        new_transition = transition.copy()
        action = new_transition.get(TransitionKey.ACTION)
        # Nothing to convert without both an action and a state.
        if action is None or state is None:
            return new_transition

        mask = self._build_mask(action.shape[-1])
        new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Serializable config (action_names is runtime-populated, so not included)."""
        return {"enabled": self.enabled, "exclude_joints": self.exclude_joints}

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Delta conversion preserves shapes/dtypes, so features pass through unchanged."""
        return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
    """Converts delta actions back to absolute actions (action += state) for the masked dimensions.

    Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
    predicted deltas are converted back to absolute positions for execution.
    Reads the cached state from its paired DeltaActionsProcessorStep, and reuses
    that step's mask so excluded joints stay untouched in both directions.

    Attributes:
        enabled: Whether to apply the absolute conversion.
        delta_step: Reference to the paired DeltaActionsProcessorStep that caches state.
    """

    enabled: bool = False
    delta_step: DeltaActionsProcessorStep | None = field(default=None, repr=False)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Add the cached state back onto the action for masked dimensions.

        Raises:
            RuntimeError: If no paired delta step is configured, or if the delta
                step has not yet observed any state to add back.
        """
        if not self.enabled:
            return transition

        if self.delta_step is None:
            raise RuntimeError(
                "AbsoluteActionsProcessorStep requires a paired DeltaActionsProcessorStep "
                "but delta_step is None. Ensure delta_step is set when constructing the postprocessor."
            )

        # NOTE(review): this reads the delta step's private cache, so correctness
        # depends on the preprocessor running (and seeing a state) first.
        if self.delta_step._last_state is None:
            raise RuntimeError(
                "AbsoluteActionsProcessorStep requires state from DeltaActionsProcessorStep "
                "but no state has been cached. Ensure the preprocessor runs before the postprocessor."
            )

        new_transition = transition.copy()
        action = new_transition.get(TransitionKey.ACTION)
        if action is None:
            return new_transition

        # Reuse the delta step's mask so both directions agree on which dims convert.
        mask = self.delta_step._build_mask(action.shape[-1])
        new_transition[TransitionKey.ACTION] = to_absolute_actions(
            action, self.delta_step._last_state, mask
        )
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Serializable configuration; the paired step reference is not serialized."""
        return {"enabled": self.enabled}

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Shape-preserving: features pass through unchanged."""
        return features
|
||||
|
||||
@@ -331,11 +331,9 @@ class _NormalizationMixin:
|
||||
)
|
||||
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
# Avoid division by zero by adding a small epsilon.
|
||||
denom = std + self.eps
|
||||
if inverse:
|
||||
return tensor * std + mean
|
||||
return (tensor - mean) / denom
|
||||
return tensor * (std + 1e-6) + mean
|
||||
return (tensor - mean) / (std + 1e-6)
|
||||
|
||||
if norm_mode == NormalizationMode.MIN_MAX:
|
||||
min_val = stats.get("min", None)
|
||||
@@ -367,11 +365,7 @@ class _NormalizationMixin:
|
||||
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
|
||||
)
|
||||
|
||||
denom = q99 - q01
|
||||
# Avoid division by zero by adding epsilon when quantiles are identical
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
||||
)
|
||||
denom = q99 - q01 + 1e-6
|
||||
if inverse:
|
||||
return (tensor + 1.0) * denom / 2.0 + q01
|
||||
return 2.0 * (tensor - q01) / denom - 1.0
|
||||
|
||||
@@ -85,6 +85,13 @@ Modify tasks - set default task with overrides for specific episodes (WARNING: m
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Trim first 3 seconds from all episodes:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_trim3s \
|
||||
--operation.type trim_episode_start \
|
||||
--operation.seconds 3.0
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -125,6 +132,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
trim_episode_start,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
@@ -169,6 +177,13 @@ class ModifyTasksConfig(OperationConfig):
|
||||
episode_tasks: dict[str, str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("trim_episode_start")
@dataclass
class TrimEpisodeStartConfig(OperationConfig):
    """Configuration for the trim_episode_start dataset edit operation."""

    # Duration (in seconds) to cut from the start of each episode.
    # Required at runtime; validated by handle_trim_episode_start.
    seconds: float | None = None
    # Episode indices to trim; None means every episode in the dataset.
    episode_indices: list[int] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("convert_image_to_video")
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
@@ -373,6 +388,41 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
modified_dataset.push_to_hub()
|
||||
|
||||
|
||||
def handle_trim_episode_start(cfg: EditDatasetConfig) -> None:
    """Handle the trim_episode_start operation: cut the first N seconds off episodes."""
    op = cfg.operation
    if not isinstance(op, TrimEpisodeStartConfig):
        raise ValueError("Operation config must be TrimEpisodeStartConfig")
    if op.seconds is None:
        raise ValueError("seconds must be specified for trim_episode_start operation")

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
    source_root = Path(cfg.root) if cfg.root else None
    output_repo_id, output_dir = get_output_path(cfg.repo_id, cfg.new_repo_id, source_root)

    # In-place edit: move the source out of the way so the output can take its path.
    if cfg.new_repo_id is None:
        dataset.root = Path(str(dataset.root) + "_old")

    logging.info(
        f"Trimming first {op.seconds}s from episodes "
        f"{op.episode_indices if op.episode_indices else 'ALL'} in {cfg.repo_id}"
    )
    new_dataset = trim_episode_start(
        dataset=dataset,
        seconds=op.seconds,
        episode_indices=op.episode_indices,
        output_dir=output_dir,
        repo_id=output_repo_id,
    )

    logging.info(f"Dataset saved to {output_dir}")
    logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")

    if cfg.push_to_hub:
        logging.info(f"Pushing to hub as {output_repo_id}")
        LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
@@ -450,6 +500,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_remove_feature(cfg)
|
||||
elif operation_type == "modify_tasks":
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "trim_episode_start":
|
||||
handle_trim_episode_start(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,366 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Mirror a bimanual robot dataset by swapping left/right arms and inverting joint values.
|
||||
|
||||
This script creates a mirrored version of a dataset where:
|
||||
1. Left and right arm observations/actions are swapped
|
||||
2. Joint values are inverted according to a mirroring mask
|
||||
3. Video frames are horizontally flipped
|
||||
|
||||
Example usage:
|
||||
```shell
|
||||
python -m lerobot.scripts.lerobot_mirror_dataset \
|
||||
--repo_id=pepijn/openarm_bimanual \
|
||||
--output_repo_id=pepijn/openarm_bimanual_mirrored
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
DEFAULT_DATA_PATH,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-joint sign multipliers used when mirroring an OpenArm across the
# left/right axis: -1 flips the joint value, 1 keeps it unchanged.
# NOTE(review): the sign choices assume the OpenArm joint axis conventions —
# confirm against the robot definition before reusing for another arm.
OPENARM_MIRRORING_MASK = {
    "joint_1": -1,  # Pan - invert
    "joint_2": -1,  # Lift - invert
    "joint_3": -1,  # Roll - invert
    "joint_4": 1,  # Elbow - no invert
    "joint_5": -1,  # W-Roll - invert
    "joint_6": -1,  # W-Pitch - invert
    "joint_7": -1,  # W-Yaw - invert
    "gripper": 1,  # Gripper - no invert
}
|
||||
|
||||
|
||||
def get_mirroring_mask(robot_type: str) -> dict[str, int]:
    """Return the per-joint sign mask used to mirror ``robot_type``.

    Raises:
        ValueError: If no mask has been registered for this robot type.
    """
    openarm_variants = {
        "bi_openarm_follower",
        "openarm_follower",
        "bi_openarms_follower",
        "openarms_follower",
    }
    if robot_type in openarm_variants:
        return OPENARM_MIRRORING_MASK
    raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
|
||||
|
||||
|
||||
def swap_left_right_name(name: str) -> str:
    """Return *name* with every 'left_' and 'right_' token exchanged."""
    # Route 'left_' through a sentinel so the second replace cannot re-swap
    # tokens produced by the first one.
    sentinel = "LEFT_PLACEHOLDER_"
    swapped = name.replace("left_", sentinel)
    swapped = swapped.replace("right_", "left_")
    return swapped.replace(sentinel, "right_")
|
||||
|
||||
|
||||
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
    """Mirror feature names by swapping left/right and return the new names and index mapping.

    Because each name is swapped in place, position ``i`` in ``names`` always
    corresponds to position ``i`` in the mirrored list, so the index mapping is
    the identity. The previous implementation recovered each index with
    ``list.index``, which was O(n^2) and, with duplicate names, collapsed
    several old indices onto the first occurrence — silently overwriting values
    and leaving zeros in ``mirror_array``'s output.

    Args:
        names: Per-dimension feature names (e.g. "left_joint_1.pos").

    Returns:
        ``(mirrored_names, old_to_new_idx)`` where ``mirrored_names[i]`` is
        ``swap_left_right_name(names[i])`` and ``old_to_new_idx`` maps every
        index to itself.
    """
    mirrored_names = [swap_left_right_name(n) for n in names]
    old_to_new_idx = {idx: idx for idx in range(len(names))}
    return mirrored_names, old_to_new_idx
|
||||
|
||||
|
||||
def apply_mirroring_mask(
    value: float,
    feature_name: str,
    mirroring_mask: dict[str, int],
) -> float:
    """Multiply *value* by the sign registered for its joint, if any.

    ``feature_name`` is expected to look like ``'<side>_<joint>.<field>'``; the
    side prefix (up to the first underscore) and the field suffix (after the
    first dot) are stripped to recover the joint key. Names whose joint is not
    in ``mirroring_mask`` pass through unchanged.
    """
    if "_" in feature_name:
        stem = feature_name.split("_", 1)[1]
    else:
        stem = feature_name
    joint = stem.split(".", 1)[0]
    sign = mirroring_mask.get(joint)
    return value * sign if sign is not None else value
|
||||
|
||||
|
||||
def mirror_array(
    array: np.ndarray,
    names: list[str],
    mirroring_mask: dict[str, int],
) -> np.ndarray:
    """Mirror a vector of values (action or state): swap left/right slots and apply the sign mask.

    Each value is written to the slot given by the index mapping and multiplied
    by the sign registered for its *mirrored* joint name.
    """
    mirrored_names, idx_mapping = mirror_feature_names(names)
    mirrored = np.zeros_like(array)
    for src_idx, dst_idx in idx_mapping.items():
        target_name = mirrored_names[dst_idx]
        mirrored[dst_idx] = apply_mirroring_mask(array[src_idx], target_name, mirroring_mask)
    return mirrored
|
||||
|
||||
|
||||
def flip_video_frames(
    input_path: Path,
    output_path: Path,
    fps: float,
    vcodec: str = "libsvtav1",
):
    """Flip video frames horizontally using FFmpeg with same settings as encode_video_frames.

    Args:
        input_path: Source video file.
        output_path: Destination path; parent directories are created as needed.
        fps: Output frame rate. Note it is truncated to an integer for ffmpeg's
            ``-r`` flag — NOTE(review): fractional fps would be silently rounded
            down, confirm datasets always use integral fps.
        vcodec: Output codec; "libsvtav1" additionally gets a speed preset.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero return code.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Argument list form (shell=False) — no shell injection risk from paths.
    cmd = [
        "ffmpeg", "-y", "-i", str(input_path),
        "-vf", "hflip",  # horizontal mirror filter
        "-c:v", vcodec,
        "-g", "2",  # keyframe interval, matches encode_video_frames settings
        "-crf", "30",
        "-r", str(int(fps)),
        "-pix_fmt", "yuv420p",
        "-loglevel", "error",
    ]
    if vcodec == "libsvtav1":
        cmd.extend(["-preset", "12"])
    cmd.append(str(output_path))
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg failed: {result.stderr}")
|
||||
|
||||
|
||||
def mirror_dataset(
    repo_id: str,
    output_repo_id: str,
    root: str | Path | None = None,
    output_root: str | Path | None = None,
    mirroring_mask: dict[str, int] | None = None,
    vcodec: str = "libsvtav1",
    num_workers: int | None = None,
) -> LeRobotDataset:
    """Mirror a bimanual robot dataset.

    Creates a new dataset in which left/right feature names are swapped, joint
    values are sign-flipped per ``mirroring_mask``, and videos are horizontally
    flipped.

    Args:
        repo_id: Source dataset repo id.
        output_repo_id: Repo id for the mirrored dataset.
        root: Optional local root of the source dataset.
        output_root: Optional output directory; defaults to HF_LEROBOT_HOME / output_repo_id.
        mirroring_mask: Per-joint sign map; inferred from the robot type when None.
        vcodec: Video codec passed through to the ffmpeg re-encode.
        num_workers: Parallel workers for video processing (None = auto).

    Returns:
        The mirrored dataset, reloaded from ``output_root``.
    """
    logger.info(f"Loading dataset: {repo_id}")
    dataset = LeRobotDataset(repo_id, root=root)

    if mirroring_mask is None:
        # NOTE(review): falls back to "bi_openarms_follower" when robot_type is
        # missing from metadata — confirm that default is intended.
        robot_type = dataset.meta.robot_type or "bi_openarms_follower"
        mirroring_mask = get_mirroring_mask(robot_type)
        logger.info(f"Using mirroring mask for robot type: {robot_type}")

    output_root = Path(output_root) if output_root else HF_LEROBOT_HOME / output_repo_id

    # Rename feature keys and their per-dimension names (left <-> right).
    mirrored_features = {}
    for key, feat in dataset.meta.features.items():
        new_key = swap_left_right_name(key)
        new_feat = feat.copy()
        if "names" in new_feat and new_feat["names"]:
            new_feat["names"] = [swap_left_right_name(n) for n in new_feat["names"]]
        mirrored_features[new_key] = new_feat

    logger.info("Creating mirrored dataset metadata...")
    new_meta = LeRobotDatasetMetadata.create(
        repo_id=output_repo_id,
        fps=dataset.meta.fps,
        features=mirrored_features,
        robot_type=dataset.meta.robot_type,
        root=output_root,
        use_videos=len(dataset.meta.video_keys) > 0,
    )

    # Tasks are direction-agnostic, so they are copied verbatim.
    if dataset.meta.tasks is not None:
        write_tasks(dataset.meta.tasks, new_meta.root)
        new_meta.tasks = dataset.meta.tasks.copy()

    # Order matters: data and videos are written before the episode metadata
    # (which records totals and stats) is finalized.
    _mirror_data(dataset, new_meta, mirroring_mask)
    _mirror_videos(dataset, new_meta, vcodec, num_workers)
    _copy_episodes_metadata(dataset, new_meta)

    logger.info(f"Mirrored dataset saved to: {output_root}")
    return LeRobotDataset(output_repo_id, root=output_root)
|
||||
|
||||
|
||||
def _mirror_data(
    src_dataset: LeRobotDataset,
    dst_meta: LeRobotDatasetMetadata,
    mirroring_mask: dict[str, int],
) -> None:
    """Rewrite every data parquet file with left/right-swapped, sign-mirrored values."""
    data_dir = src_dataset.root / DATA_DIR
    parquet_files = sorted(data_dir.glob("*/*.parquet"))
    if not parquet_files:
        raise ValueError(f"No parquet files found in {data_dir}")

    action_names = src_dataset.meta.features.get("action", {}).get("names", [])
    state_names = src_dataset.meta.features.get("observation.state", {}).get("names", [])

    def _mirror_column(frame, column: str, names: list[str]) -> None:
        # Mirror one vector-valued column in place; skip absent or unnamed columns.
        if column not in frame.columns or not names:
            return
        rows = np.stack(frame[column].values)
        mirrored = np.array([mirror_array(row, names, mirroring_mask) for row in rows])
        frame[column] = list(mirrored)

    for src_path in tqdm(parquet_files, desc="Mirroring data files"):
        df = pd.read_parquet(src_path).reset_index(drop=True)

        # Recover chunk/file indices from the "chunk-XXX/file-YYY.parquet" layout.
        rel = src_path.relative_to(src_dataset.root)
        chunk_idx = int(rel.parts[1].split("-")[1])
        file_idx = int(rel.parts[2].split("-")[1].split(".")[0])

        _mirror_column(df, "action", action_names)
        _mirror_column(df, "observation.state", state_names)

        dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        dst_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
def _mirror_videos(
    src_dataset: LeRobotDataset,
    dst_meta: LeRobotDatasetMetadata,
    vcodec: str,
    num_workers: int | None = None,
) -> None:
    """Mirror video files by flipping horizontally and swapping left/right names.

    Args:
        src_dataset: Source dataset whose videos are read.
        dst_meta: Destination metadata providing the output root.
        vcodec: Codec passed through to flip_video_frames.
        num_workers: Thread count; defaults to the CPU count.
    """
    if not src_dataset.meta.video_keys:
        return

    # Collect (src, dst) pairs; the destination path swaps the camera key name.
    video_tasks = []
    for old_video_key in src_dataset.meta.video_keys:
        new_video_key = swap_left_right_name(old_video_key)
        for ep_idx in range(src_dataset.meta.total_episodes):
            try:
                src_path = src_dataset.root / src_dataset.meta.get_video_file_path(ep_idx, old_video_key)
                dst_relative = src_dataset.meta.get_video_file_path(ep_idx, old_video_key)
                dst_relative_str = str(dst_relative).replace(old_video_key, new_video_key)
                dst_path = dst_meta.root / dst_relative_str
                if src_path.exists():
                    video_tasks.append((src_path, dst_path))
            except KeyError:
                continue

    # Bug fix: with zero tasks, min(...) below would produce max_workers=0 and
    # ThreadPoolExecutor raises ValueError for a non-positive max_workers.
    if not video_tasks:
        logger.info("No video files found to mirror")
        return

    def process_video(task, pbar):
        # Runs in a worker thread; returns the source path for logging.
        src_path, dst_path = task
        pbar.set_postfix_str(src_path.name)
        flip_video_frames(src_path, dst_path, src_dataset.meta.fps, vcodec)
        return src_path

    if num_workers is None:
        num_workers = os.cpu_count() or 16
    num_workers = min(len(video_tasks), num_workers)
    logger.info(f"Processing {len(video_tasks)} videos with {num_workers} workers")
    with tqdm(total=len(video_tasks), desc="Mirroring videos") as pbar:
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {executor.submit(process_video, t, pbar): t for t in video_tasks}
            for future in as_completed(futures):
                task = futures[future]
                # Re-raise any worker exception instead of swallowing it.
                future.result()
                pbar.set_postfix_str(f"done: {task[0].name}")
                pbar.update(1)
|
||||
|
||||
|
||||
def _copy_episodes_metadata(
    src_dataset: LeRobotDataset,
    dst_meta: LeRobotDatasetMetadata,
) -> None:
    """Copy episodes metadata with swapped video keys.

    Copies the per-episode parquet metadata (renaming "videos/<key>/..." columns
    to their mirrored camera key), carries over the dataset totals into the new
    info file, and writes mirrored stats when the source has them.
    """
    episodes_dir = src_dataset.root / "meta/episodes"
    dst_episodes_dir = dst_meta.root / "meta/episodes"

    if episodes_dir.exists():
        dst_episodes_dir.mkdir(parents=True, exist_ok=True)
        for src_parquet in episodes_dir.glob("*/*.parquet"):
            df = pd.read_parquet(src_parquet)
            # Columns look like "videos/<camera_key>/<field>"; only the camera
            # key segment is mirrored.
            columns_to_rename = {}
            for col in df.columns:
                if col.startswith("videos/"):
                    parts = col.split("/")
                    if len(parts) >= 2:
                        video_key = parts[1]
                        new_video_key = swap_left_right_name(video_key)
                        new_col = col.replace(f"videos/{video_key}/", f"videos/{new_video_key}/")
                        columns_to_rename[col] = new_col
            if columns_to_rename:
                df = df.rename(columns=columns_to_rename)
            # Preserve the chunked directory layout under the destination root.
            dst_parquet = dst_episodes_dir / src_parquet.relative_to(episodes_dir)
            dst_parquet.parent.mkdir(parents=True, exist_ok=True)
            df.to_parquet(dst_parquet, index=False)

    # Mirroring preserves episode/frame/video counts, so totals copy verbatim.
    dst_meta.info.update({
        "total_episodes": src_dataset.meta.total_episodes,
        "total_frames": src_dataset.meta.total_frames,
        "total_tasks": src_dataset.meta.total_tasks,
        "total_videos": src_dataset.meta.total_videos,
        "total_chunks": src_dataset.meta.total_chunks,
    })
    write_info(dst_meta.info, dst_meta.root)

    # NOTE(review): stats keys are renamed but the numeric values are copied
    # unchanged — sign-flipped joints will have stale mean/min/max signs;
    # confirm whether stats should be recomputed instead.
    if src_dataset.meta.stats is not None:
        mirrored_stats = _mirror_stats(src_dataset.meta.stats)
        write_stats(mirrored_stats, dst_meta.root)
|
||||
|
||||
|
||||
def _mirror_stats(stats: dict) -> dict:
    """Recursively rebuild a stats dict with left/right swapped in every key."""
    return {
        swap_left_right_name(key): _mirror_stats(value) if isinstance(value, dict) else value
        for key, value in stats.items()
    }
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, mirror the dataset, optionally push it."""
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description="Mirror a bimanual robot dataset")
    parser.add_argument("--repo_id", type=str, required=True, help="Source dataset repo_id")
    parser.add_argument("--output_repo_id", type=str, required=True, help="Output dataset repo_id")
    parser.add_argument("--root", type=str, default=None, help="Source dataset root directory")
    parser.add_argument("--output_root", type=str, default=None, help="Output dataset root directory")
    parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec (libsvtav1, h264, hevc)")
    parser.add_argument("--num_workers", type=int, default=None, help="Number of parallel workers for video processing")
    parser.add_argument("--push-to-hub", action="store_true", help="Push mirrored dataset to HuggingFace Hub")
    args = parser.parse_args()

    mirrored = mirror_dataset(
        repo_id=args.repo_id,
        output_repo_id=args.output_repo_id,
        root=args.root,
        output_root=args.output_root,
        vcodec=args.vcodec,
        num_workers=args.num_workers,
    )

    # argparse maps "--push-to-hub" onto the push_to_hub attribute, which is
    # always present, so no getattr fallback is needed.
    if args.push_to_hub:
        logger.info(f"Pushing dataset to HuggingFace Hub: {args.output_repo_id}")
        mirrored.push_to_hub()


if __name__ == "__main__":
    main()
|
||||
|
||||
@@ -175,8 +175,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when policy.device is set to CPU.
|
||||
force_cpu = cfg.policy.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
@@ -211,16 +209,98 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
delta_action_stats = None
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Compute delta action stats BEFORE distributed sync to avoid NCCL timeout
|
||||
if getattr(cfg.policy, "use_delta_actions", False):
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.processor.delta_action_processor import DeltaActionsProcessorStep, to_delta_actions
|
||||
|
||||
chunk_size = cfg.policy.chunk_size
|
||||
hf = dataset.hf_dataset
|
||||
total_frames = len(hf)
|
||||
sample_upper_bound = total_frames - chunk_size
|
||||
if sample_upper_bound <= 0:
|
||||
raise ValueError(
|
||||
f"Cannot compute delta action stats: total_frames={total_frames}, chunk_size={chunk_size}"
|
||||
)
|
||||
|
||||
max_samples = min(100_000, sample_upper_bound)
|
||||
indices = np.random.choice(sample_upper_bound, max_samples, replace=False)
|
||||
|
||||
action_names = dataset.meta.features.get("action", {}).get("names")
|
||||
delta_mask_step = DeltaActionsProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=getattr(cfg.policy, "delta_exclude_joints", []),
|
||||
action_names=action_names,
|
||||
)
|
||||
delta_mask = delta_mask_step._build_mask(dataset.meta.features["action"]["shape"][0])
|
||||
logging.info(
|
||||
f"use_delta_actions is enabled — computing delta action stats "
|
||||
f"from {max_samples} chunk samples (chunk_size={chunk_size})"
|
||||
)
|
||||
|
||||
all_delta_actions = []
|
||||
episode_indices = np.array(hf["episode_index"])
|
||||
for idx in indices:
|
||||
idx = int(idx)
|
||||
ep_idx = episode_indices[idx]
|
||||
end_idx = min(idx + chunk_size, total_frames)
|
||||
if end_idx > idx and episode_indices[end_idx - 1] != ep_idx:
|
||||
continue
|
||||
|
||||
chunk_data = hf[idx:end_idx]
|
||||
actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float()
|
||||
state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float()
|
||||
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), delta_mask).squeeze(0)
|
||||
all_delta_actions.append(delta.numpy())
|
||||
|
||||
if not all_delta_actions:
|
||||
raise RuntimeError("Failed to compute delta action stats: no valid chunks found.")
|
||||
|
||||
all_delta = np.concatenate(all_delta_actions, axis=0)
|
||||
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||
delta_action_stats = delta_stats
|
||||
dataset.meta.stats["action"] = delta_action_stats
|
||||
|
||||
norm_type = "UNKNOWN"
|
||||
if hasattr(cfg.policy, "normalization_mapping"):
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
action_norm = cfg.policy.normalization_mapping.get("ACTION", None)
|
||||
norm_type = action_norm.value if action_norm else "UNKNOWN"
|
||||
|
||||
excluded_dims = len(delta_mask) - sum(delta_mask)
|
||||
logging.info(
|
||||
f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): "
|
||||
f"delta_dims={sum(delta_mask)}/{len(delta_mask)} (excluded={excluded_dims}), "
|
||||
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, "
|
||||
f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}"
|
||||
)
|
||||
if norm_type == "QUANTILES":
|
||||
q_range = (delta_stats['q99'] - delta_stats['q01']).mean()
|
||||
logging.info(f" Quantile range (q99-q01): {q_range:.4f}")
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Ensure all ranks use the exact same delta action stats.
|
||||
if getattr(cfg.policy, "use_delta_actions", False):
|
||||
if accelerator.num_processes > 1 and torch.distributed.is_initialized():
|
||||
stats_list = [delta_action_stats]
|
||||
torch.distributed.broadcast_object_list(stats_list, src=0)
|
||||
delta_action_stats = stats_list[0]
|
||||
if delta_action_stats is not None:
|
||||
dataset.meta.stats["action"] = delta_action_stats
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
@@ -246,10 +326,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
processor_pretrained_path = cfg.policy.pretrained_path
|
||||
if (
|
||||
getattr(cfg.policy, "use_delta_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_delta_actions=true with pretrained processors can skip delta transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
@@ -257,7 +349,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
if processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
@@ -279,7 +371,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
@@ -397,7 +489,36 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
|
||||
# Debug logging for first few steps and periodically
|
||||
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||
action = batch.get("action")
|
||||
state = batch.get("observation.state")
|
||||
if action is not None and state is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] PRE-PROCESSOR — "
|
||||
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||
f"min={action.min():.4f}, max={action.max():.4f} | "
|
||||
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}"
|
||||
)
|
||||
|
||||
batch = preprocessor(batch)
|
||||
|
||||
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
|
||||
action = batch.get("action")
|
||||
state = batch.get("observation.state")
|
||||
if action is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
|
||||
f"min={action.min():.4f}, max={action.max():.4f}"
|
||||
)
|
||||
if state is not None:
|
||||
logging.info(
|
||||
f"[DEBUG step={step}] POST-PROCESSOR — "
|
||||
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}, std={state.std():.4f}"
|
||||
)
|
||||
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
|
||||
@@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
trim_episode_start,
|
||||
)
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
||||
|
||||
@@ -142,6 +143,104 @@ def test_delete_empty_list(sample_dataset, tmp_path):
|
||||
)
|
||||
|
||||
|
||||
def test_trim_episode_start_updates_indices(sample_dataset, tmp_path):
|
||||
"""Test trimming episode starts updates frame/timestamp/index metadata consistently."""
|
||||
output_dir = tmp_path / "trimmed"
|
||||
trim_seconds = 0.1 # 3 frames at 30 FPS
|
||||
trim_frames = int(trim_seconds * sample_dataset.meta.fps)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
new_dataset = trim_episode_start(
|
||||
sample_dataset,
|
||||
seconds=trim_seconds,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
expected_length = 10 - trim_frames
|
||||
assert new_dataset.meta.total_episodes == sample_dataset.meta.total_episodes
|
||||
assert new_dataset.meta.total_frames == sample_dataset.meta.total_episodes * expected_length
|
||||
|
||||
indices = [int(i.item()) for i in new_dataset.hf_dataset["index"]]
|
||||
assert indices == list(range(new_dataset.meta.total_frames))
|
||||
|
||||
episode_indices = [int(i.item()) for i in new_dataset.hf_dataset["episode_index"]]
|
||||
frame_indices = [int(i.item()) for i in new_dataset.hf_dataset["frame_index"]]
|
||||
timestamps = [float(i.item()) for i in new_dataset.hf_dataset["timestamp"]]
|
||||
|
||||
for ep_idx in range(sample_dataset.meta.total_episodes):
|
||||
ep_frame_indices = [f for e, f in zip(episode_indices, frame_indices, strict=False) if e == ep_idx]
|
||||
ep_timestamps = [t for e, t in zip(episode_indices, timestamps, strict=False) if e == ep_idx]
|
||||
|
||||
assert len(ep_frame_indices) == expected_length
|
||||
assert ep_frame_indices == list(range(expected_length))
|
||||
assert ep_timestamps[0] == pytest.approx(0.0)
|
||||
assert ep_timestamps[-1] == pytest.approx((expected_length - 1) / sample_dataset.meta.fps)
|
||||
|
||||
ep_meta = new_dataset.meta.episodes[ep_idx]
|
||||
assert int(ep_meta["length"]) == expected_length
|
||||
assert int(ep_meta["dataset_from_index"]) == ep_idx * expected_length
|
||||
assert int(ep_meta["dataset_to_index"]) == (ep_idx + 1) * expected_length
|
||||
|
||||
|
||||
def test_trim_episode_start_skips_too_short_episodes(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test too-short episodes are skipped and remaining episodes are reindexed."""
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": None},
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": None},
|
||||
"observation.images.top": {"dtype": "image", "shape": (32, 32, 3), "names": None},
|
||||
}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "source", features=features)
|
||||
|
||||
for ep_len in [10, 2, 10]:
|
||||
for _ in range(ep_len):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"observation.state": np.random.randn(2).astype(np.float32),
|
||||
"observation.images.top": np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8),
|
||||
"task": "task",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
trim_seconds = 0.1 # 3 frames at 30 FPS
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "trimmed")
|
||||
|
||||
new_dataset = trim_episode_start(
|
||||
dataset,
|
||||
seconds=trim_seconds,
|
||||
output_dir=tmp_path / "trimmed",
|
||||
)
|
||||
|
||||
# Episode 1 is too short and gets skipped. Remaining episodes are trimmed and reindexed.
|
||||
assert new_dataset.meta.total_episodes == 2
|
||||
assert new_dataset.meta.total_frames == 14
|
||||
assert sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) == [0, 1]
|
||||
assert [int(ep["length"]) for ep in new_dataset.meta.episodes] == [7, 7]
|
||||
|
||||
|
||||
def test_trim_episode_start_rejects_when_all_selected_are_too_short(sample_dataset, tmp_path):
|
||||
"""Test trimming fails when all selected episodes are too short and would be skipped."""
|
||||
with pytest.raises(ValueError, match="All episodes selected for trimming are too short"):
|
||||
trim_episode_start(
|
||||
sample_dataset,
|
||||
seconds=1.0, # 30 frames > 10-frame episodes
|
||||
output_dir=tmp_path / "trimmed",
|
||||
)
|
||||
|
||||
|
||||
def test_split_by_episodes(sample_dataset, tmp_path):
|
||||
"""Test splitting dataset by specific episode indices."""
|
||||
splits = {
|
||||
|
||||
@@ -0,0 +1,344 @@
|
||||
"""Tests for delta action transforms — full pipeline validation.
|
||||
|
||||
Tests the complete flow matching OpenPI:
|
||||
raw actions → DeltaActions → Normalize(delta_stats) → model → Unnormalize → AbsoluteActions
|
||||
|
||||
Uses real dataset: lerobot-data-collection/dagger_final_1_21
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import TransitionKey, batch_to_transition
|
||||
from lerobot.processor.delta_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
CHUNK_SIZE = 10
|
||||
REPO_ID = "lerobot-data-collection/dagger_final_1_21"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dataset():
|
||||
return LeRobotDataset(REPO_ID, episodes=[0])
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def action_dim(dataset):
|
||||
return dataset.meta.features["action"]["shape"][0]
|
||||
|
||||
|
||||
def _build_action_chunks(dataset, chunk_size, max_chunks=50):
|
||||
"""Build action chunks from hf_dataset, like the training script does."""
|
||||
hf = dataset.hf_dataset
|
||||
total = len(hf)
|
||||
all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)])
|
||||
chunks, states = [], []
|
||||
for i in range(total - chunk_size + 1):
|
||||
if all_ep[i] != all_ep[i + chunk_size - 1]:
|
||||
continue
|
||||
chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float()
|
||||
state = hf[i]["observation.state"].float()
|
||||
chunks.append(chunk_actions)
|
||||
states.append(state)
|
||||
if len(chunks) >= max_chunks:
|
||||
break
|
||||
assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}"
|
||||
return torch.stack(chunks), torch.stack(states)
|
||||
|
||||
|
||||
def _compute_delta_chunk_stats(action_chunks, states, mask):
|
||||
all_deltas = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_deltas.append(delta.numpy())
|
||||
all_delta = np.concatenate(all_deltas, axis=0)
|
||||
return get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||
|
||||
|
||||
# --- Basic roundtrip tests ---
|
||||
|
||||
def test_roundtrip_3d(action_dim):
|
||||
actions = torch.randn(4, CHUNK_SIZE, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_roundtrip_2d(action_dim):
|
||||
actions = torch.randn(4, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_no_mutation(action_dim):
|
||||
actions = torch.randn(2, CHUNK_SIZE, action_dim)
|
||||
original = actions.clone()
|
||||
state = torch.randn(2, action_dim)
|
||||
to_delta_actions(actions, state, [True] * action_dim)
|
||||
torch.testing.assert_close(actions, original)
|
||||
|
||||
|
||||
def test_exclude_joints_supports_partial_name_matching():
|
||||
names = [
|
||||
"right_joint_1.pos",
|
||||
"right_gripper.pos",
|
||||
"left_joint_1.pos",
|
||||
"left_gripper.pos",
|
||||
]
|
||||
step = DeltaActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names)
|
||||
assert step._build_mask(len(names)) == [True, False, True, False]
|
||||
|
||||
|
||||
# --- Chunk-level delta stats test ---
|
||||
|
||||
def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim):
|
||||
"""Chunk-level delta stats should have larger std than per-frame delta stats."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
chunk_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Per-frame stats
|
||||
hf = dataset.hf_dataset
|
||||
n = min(500, len(hf))
|
||||
frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float()
|
||||
frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float()
|
||||
frame_deltas = to_delta_actions(frame_actions, frame_states, mask).numpy()
|
||||
frame_stats = get_feature_stats(frame_deltas, axis=0, keepdims=frame_deltas.ndim == 1)
|
||||
|
||||
assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), (
|
||||
f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= "
|
||||
f"frame std ({frame_stats['std'].mean():.4f})"
|
||||
)
|
||||
|
||||
|
||||
# --- Full pipeline roundtrip: delta → normalize → unnormalize → absolute ---
|
||||
|
||||
def test_full_pipeline_roundtrip(dataset, action_dim):
|
||||
"""Test the complete OpenPI pipeline: delta → normalize → unnormalize → absolute."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
|
||||
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: delta → normalize
|
||||
t1 = delta_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized_action = t2[TransitionKey.ACTION]
|
||||
assert normalized_action.abs().mean() < 10, (
|
||||
f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}"
|
||||
)
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered_actions = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
def test_normalized_delta_values_are_reasonable(dataset, action_dim):
|
||||
"""With correct chunk stats, normalized delta actions should be in a reasonable range."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
mean = torch.tensor(delta_stats["mean"]).float()
|
||||
std = torch.tensor(delta_stats["std"]).float()
|
||||
|
||||
all_normalized = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
normalized = (delta - mean) / (std + 1e-6)
|
||||
all_normalized.append(normalized)
|
||||
|
||||
all_normalized = torch.cat(all_normalized, dim=0)
|
||||
|
||||
pct_in_range = (all_normalized.abs() < 5).float().mean()
|
||||
assert pct_in_range > 0.9, (
|
||||
f"Only {pct_in_range*100:.1f}% of normalized values in [-5, 5], expected >90%"
|
||||
)
|
||||
|
||||
assert all_normalized.mean().abs() < 1.0, (
|
||||
f"Mean of normalized deltas is {all_normalized.mean():.2f}, expected near 0"
|
||||
)
|
||||
|
||||
|
||||
def test_processor_step_roundtrip(dataset, action_dim):
|
||||
"""DeltaActionsProcessorStep applies delta; to_absolute_actions recovers original."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_actions = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = DeltaActionsProcessorStep(enabled=True)
|
||||
delta_transition = step(transition)
|
||||
assert not torch.allclose(delta_transition[TransitionKey.ACTION], original_actions)
|
||||
|
||||
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(delta_transition[TransitionKey.ACTION], state, mask)
|
||||
torch.testing.assert_close(recovered, original_actions)
|
||||
|
||||
|
||||
def test_processor_step_disabled_is_noop(dataset, action_dim):
|
||||
"""enabled=False should be a no-op."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(2)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]),
|
||||
}
|
||||
original = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
result = DeltaActionsProcessorStep(enabled=False)(transition)
|
||||
torch.testing.assert_close(result[TransitionKey.ACTION], original)
|
||||
|
||||
|
||||
# --- Training batch shape validation ---
|
||||
|
||||
def test_delta_with_action_chunks(dataset, action_dim):
|
||||
"""Verify delta works correctly with (B, chunk_size, action_dim) shaped actions."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
|
||||
# Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim)
|
||||
batch_actions = action_chunks[:4] # (4, chunk_size, action_dim)
|
||||
batch_states = states[:4] # (4, state_dim)
|
||||
|
||||
mask = [True] * action_dim
|
||||
delta = to_delta_actions(batch_actions, batch_states, mask)
|
||||
|
||||
# First action in each chunk should be close to zero (action[t] - state[t] ≈ small)
|
||||
first_deltas = delta[:, 0, :] # (B, action_dim)
|
||||
assert first_deltas.abs().mean() < delta.abs().mean(), (
|
||||
f"First action in chunk should have smaller delta than average. "
|
||||
f"First: {first_deltas.abs().mean():.4f}, Average: {delta.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Later actions should have larger deltas
|
||||
last_deltas = delta[:, -1, :] # (B, action_dim)
|
||||
assert last_deltas.abs().mean() >= first_deltas.abs().mean(), (
|
||||
f"Last action in chunk should have >= delta than first. "
|
||||
f"Last: {last_deltas.abs().mean():.4f}, First: {first_deltas.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Roundtrip
|
||||
recovered = to_absolute_actions(delta, batch_states, mask)
|
||||
torch.testing.assert_close(recovered, batch_actions)
|
||||
|
||||
|
||||
def test_delta_stats_match_actual_data_distribution(dataset, action_dim):
|
||||
"""Verify computed stats match the actual delta distribution."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
# Compute stats like the training script does
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Also compute directly
|
||||
all_deltas = []
|
||||
for actions, state in zip(action_chunks, states):
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_deltas.append(delta)
|
||||
all_deltas_tensor = torch.cat(all_deltas, dim=0)
|
||||
|
||||
# Compare mean
|
||||
actual_mean = all_deltas_tensor.mean(dim=0).numpy()
|
||||
np.testing.assert_allclose(delta_stats["mean"], actual_mean, atol=0.01)
|
||||
|
||||
# Compare std
|
||||
actual_std = all_deltas_tensor.std(dim=0).numpy()
|
||||
np.testing.assert_allclose(delta_stats["std"], actual_std, atol=0.1)
|
||||
|
||||
# Verify q01 < mean < q99
|
||||
assert (delta_stats["q01"] < delta_stats["mean"]).all(), "q01 should be < mean"
|
||||
assert (delta_stats["mean"] < delta_stats["q99"]).all(), "mean should be < q99"
|
||||
|
||||
|
||||
def test_quantile_normalization_roundtrip(dataset, action_dim):
|
||||
"""Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05)."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES}
|
||||
|
||||
delta_step = DeltaActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: delta → quantile normalize
|
||||
t1 = delta_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized = t2[TransitionKey.ACTION]
|
||||
# Most values should be in [-1, 1] with quantile normalization
|
||||
pct_in_range = (normalized.abs() < 2).float().mean()
|
||||
assert pct_in_range > 0.5, (
|
||||
f"Only {pct_in_range*100:.1f}% in [-2, 2] after quantile norm, expected >50%"
|
||||
)
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def test_state_not_modified_by_delta(dataset, action_dim):
|
||||
"""State should never be modified by the delta processor."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_state = batch[OBS_STATE].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = DeltaActionsProcessorStep(enabled=True)
|
||||
result = step(transition)
|
||||
|
||||
result_state = result[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
torch.testing.assert_close(result_state, original_state)
|
||||
@@ -26,6 +26,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
||||
OperationConfig,
|
||||
RemoveFeatureConfig,
|
||||
SplitConfig,
|
||||
TrimEpisodeStartConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -45,6 +46,7 @@ class TestOperationTypeParsing:
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("trim_episode_start", TrimEpisodeStartConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
],
|
||||
)
|
||||
@@ -62,6 +64,7 @@ class TestOperationTypeParsing:
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("trim_episode_start", TrimEpisodeStartConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user