mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
add recomputation of stats and option to compute delta stats
This commit is contained in:
@@ -37,7 +37,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats, get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
@@ -1522,6 +1522,110 @@ def modify_tasks(
|
||||
return dataset
|
||||
|
||||
|
||||
def recompute_stats(
|
||||
dataset: LeRobotDataset,
|
||||
skip_image_video: bool = True,
|
||||
delta_action: bool = False,
|
||||
delta_exclude_joints: list[str] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Recompute stats.json from scratch by iterating all episodes.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobotDataset to recompute stats for.
|
||||
skip_image_video: If True (default), only recompute stats for numeric features
|
||||
(action, state, etc.) and keep existing image/video stats unchanged.
|
||||
delta_action: If True, compute action stats as delta (action - state).
|
||||
Useful when training with use_delta_actions=True so normalization matches.
|
||||
delta_exclude_joints: Joint names to exclude from delta conversion when
|
||||
delta_action=True. These dims keep absolute stats. Uses dataset's
|
||||
action feature names to build the mask. Default: ["gripper"].
|
||||
|
||||
Returns:
|
||||
The same dataset with updated stats.
|
||||
"""
|
||||
features = dataset.meta.features
|
||||
numeric_features = {
|
||||
k: v for k, v in features.items()
|
||||
if v["dtype"] not in ["image", "video", "string"]
|
||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||
}
|
||||
|
||||
if skip_image_video:
|
||||
features_to_compute = numeric_features
|
||||
else:
|
||||
features_to_compute = {
|
||||
k: v for k, v in features.items()
|
||||
if v["dtype"] != "string"
|
||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||
}
|
||||
|
||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||
|
||||
# Build delta mask if delta_action is enabled
|
||||
delta_mask = None
|
||||
if delta_action and "action" in features and "observation.state" in features:
|
||||
if delta_exclude_joints is None:
|
||||
delta_exclude_joints = ["gripper"]
|
||||
action_names = features["action"].get("names")
|
||||
if action_names is not None:
|
||||
exclude = set(delta_exclude_joints)
|
||||
delta_mask = [n not in exclude for n in action_names]
|
||||
else:
|
||||
action_dim = features["action"]["shape"][0]
|
||||
delta_mask = [True] * action_dim
|
||||
logging.info(f"Delta action stats enabled (exclude: {delta_exclude_joints})")
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
for ep_idx in sorted(df["episode_index"].unique()):
|
||||
ep_df = df[df["episode_index"] == ep_idx]
|
||||
episode_data = {}
|
||||
for key in numeric_keys:
|
||||
if key in ep_df.columns:
|
||||
values = ep_df[key].values
|
||||
if hasattr(values[0], "__len__"):
|
||||
episode_data[key] = np.stack(values)
|
||||
else:
|
||||
episode_data[key] = np.array(values)
|
||||
|
||||
# Apply delta conversion to actions before computing stats
|
||||
if delta_mask is not None and "action" in episode_data and "observation.state" in episode_data:
|
||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||
|
||||
actions_t = torch.from_numpy(episode_data["action"]).float()
|
||||
states_t = torch.from_numpy(episode_data["observation.state"]).float()
|
||||
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
|
||||
|
||||
ep_stats = compute_episode_stats(episode_data, features_to_compute)
|
||||
all_episode_stats.append(ep_stats)
|
||||
|
||||
if not all_episode_stats:
|
||||
logging.warning("No episode stats computed")
|
||||
return dataset
|
||||
|
||||
new_stats = aggregate_stats(all_episode_stats)
|
||||
|
||||
if skip_image_video and dataset.meta.stats:
|
||||
for key, value in dataset.meta.stats.items():
|
||||
if key not in new_stats:
|
||||
new_stats[key] = value
|
||||
|
||||
write_stats(new_stats, dataset.root)
|
||||
dataset.meta.stats = new_stats
|
||||
|
||||
logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
|
||||
return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
|
||||
@@ -470,6 +470,13 @@ def make_policy(
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
|
||||
# Store action feature names on config for delta_exclude_joints support
|
||||
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
|
||||
@@ -50,8 +50,12 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state)
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta conversion (kept as absolute).
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime by make_policy from dataset metadata.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
@@ -1221,6 +1221,18 @@ class PI0Policy(PreTrainedPolicy):
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def _build_delta_mask(self, action_dim: int) -> list[bool]:
|
||||
"""Build a boolean mask for delta action conversion.
|
||||
|
||||
Uses action_feature_names and delta_exclude_joints to determine which
|
||||
dims get delta conversion. Falls back to all-True if names are unavailable.
|
||||
"""
|
||||
names = self.config.action_feature_names
|
||||
if names is None:
|
||||
return [True] * action_dim
|
||||
exclude = set(self.config.delta_exclude_joints)
|
||||
return [n not in exclude for n in names]
|
||||
|
||||
def prepare_action(self, batch):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
@@ -1261,7 +1273,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
if self.config.use_delta_actions:
|
||||
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
|
||||
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
|
||||
|
||||
return actions
|
||||
|
||||
@@ -1281,7 +1293,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
if self.config.use_delta_actions:
|
||||
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
|
||||
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
|
||||
@@ -52,6 +52,10 @@ class PI05Config(PreTrainedConfig):
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
# Joint names to exclude from delta conversion (kept as absolute).
|
||||
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
# Populated at runtime by make_policy from dataset metadata.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
@@ -1201,6 +1201,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
def _build_delta_mask(self, action_dim: int) -> list[bool]:
|
||||
"""Build a boolean mask for delta action conversion."""
|
||||
names = self.config.action_feature_names
|
||||
if names is None:
|
||||
return [True] * action_dim
|
||||
exclude = set(self.config.delta_exclude_joints)
|
||||
return [n not in exclude for n in names]
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
@@ -1236,7 +1244,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
if self.config.use_delta_actions:
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
|
||||
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
|
||||
|
||||
return actions
|
||||
|
||||
@@ -1257,7 +1265,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
if self.config.use_delta_actions:
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
|
||||
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
|
||||
@@ -243,6 +243,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
# Recompute action stats as delta if use_delta_actions is enabled
|
||||
if getattr(cfg.policy, "use_delta_actions", False) and is_main_process:
|
||||
logging.info("use_delta_actions is enabled — recomputing action stats as delta (action - state)")
|
||||
from lerobot.datasets.dataset_tools import recompute_stats
|
||||
|
||||
exclude = getattr(cfg.policy, "delta_exclude_joints", ["gripper"])
|
||||
recompute_stats(dataset, skip_image_video=True, delta_action=True, delta_exclude_joints=exclude)
|
||||
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user