mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
add recomputation of stats and option to compute delta stats
This commit is contained in:
@@ -37,7 +37,7 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.datasets.aggregate import aggregate_datasets
|
from lerobot.datasets.aggregate import aggregate_datasets
|
||||||
from lerobot.datasets.compute_stats import aggregate_stats
|
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats, get_feature_stats
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
DATA_DIR,
|
DATA_DIR,
|
||||||
@@ -1522,6 +1522,110 @@ def modify_tasks(
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def recompute_stats(
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
skip_image_video: bool = True,
|
||||||
|
delta_action: bool = False,
|
||||||
|
delta_exclude_joints: list[str] | None = None,
|
||||||
|
) -> LeRobotDataset:
|
||||||
|
"""Recompute stats.json from scratch by iterating all episodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The LeRobotDataset to recompute stats for.
|
||||||
|
skip_image_video: If True (default), only recompute stats for numeric features
|
||||||
|
(action, state, etc.) and keep existing image/video stats unchanged.
|
||||||
|
delta_action: If True, compute action stats as delta (action - state).
|
||||||
|
Useful when training with use_delta_actions=True so normalization matches.
|
||||||
|
delta_exclude_joints: Joint names to exclude from delta conversion when
|
||||||
|
delta_action=True. These dims keep absolute stats. Uses dataset's
|
||||||
|
action feature names to build the mask. Default: ["gripper"].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same dataset with updated stats.
|
||||||
|
"""
|
||||||
|
features = dataset.meta.features
|
||||||
|
numeric_features = {
|
||||||
|
k: v for k, v in features.items()
|
||||||
|
if v["dtype"] not in ["image", "video", "string"]
|
||||||
|
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||||
|
}
|
||||||
|
|
||||||
|
if skip_image_video:
|
||||||
|
features_to_compute = numeric_features
|
||||||
|
else:
|
||||||
|
features_to_compute = {
|
||||||
|
k: v for k, v in features.items()
|
||||||
|
if v["dtype"] != "string"
|
||||||
|
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||||
|
}
|
||||||
|
|
||||||
|
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||||
|
|
||||||
|
# Build delta mask if delta_action is enabled
|
||||||
|
delta_mask = None
|
||||||
|
if delta_action and "action" in features and "observation.state" in features:
|
||||||
|
if delta_exclude_joints is None:
|
||||||
|
delta_exclude_joints = ["gripper"]
|
||||||
|
action_names = features["action"].get("names")
|
||||||
|
if action_names is not None:
|
||||||
|
exclude = set(delta_exclude_joints)
|
||||||
|
delta_mask = [n not in exclude for n in action_names]
|
||||||
|
else:
|
||||||
|
action_dim = features["action"]["shape"][0]
|
||||||
|
delta_mask = [True] * action_dim
|
||||||
|
logging.info(f"Delta action stats enabled (exclude: {delta_exclude_joints})")
|
||||||
|
|
||||||
|
data_dir = dataset.root / DATA_DIR
|
||||||
|
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||||
|
if not parquet_files:
|
||||||
|
raise ValueError(f"No parquet files found in {data_dir}")
|
||||||
|
|
||||||
|
all_episode_stats = []
|
||||||
|
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||||
|
|
||||||
|
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||||
|
df = pd.read_parquet(parquet_path)
|
||||||
|
|
||||||
|
for ep_idx in sorted(df["episode_index"].unique()):
|
||||||
|
ep_df = df[df["episode_index"] == ep_idx]
|
||||||
|
episode_data = {}
|
||||||
|
for key in numeric_keys:
|
||||||
|
if key in ep_df.columns:
|
||||||
|
values = ep_df[key].values
|
||||||
|
if hasattr(values[0], "__len__"):
|
||||||
|
episode_data[key] = np.stack(values)
|
||||||
|
else:
|
||||||
|
episode_data[key] = np.array(values)
|
||||||
|
|
||||||
|
# Apply delta conversion to actions before computing stats
|
||||||
|
if delta_mask is not None and "action" in episode_data and "observation.state" in episode_data:
|
||||||
|
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||||
|
|
||||||
|
actions_t = torch.from_numpy(episode_data["action"]).float()
|
||||||
|
states_t = torch.from_numpy(episode_data["observation.state"]).float()
|
||||||
|
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
|
||||||
|
|
||||||
|
ep_stats = compute_episode_stats(episode_data, features_to_compute)
|
||||||
|
all_episode_stats.append(ep_stats)
|
||||||
|
|
||||||
|
if not all_episode_stats:
|
||||||
|
logging.warning("No episode stats computed")
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
new_stats = aggregate_stats(all_episode_stats)
|
||||||
|
|
||||||
|
if skip_image_video and dataset.meta.stats:
|
||||||
|
for key, value in dataset.meta.stats.items():
|
||||||
|
if key not in new_stats:
|
||||||
|
new_stats[key] = value
|
||||||
|
|
||||||
|
write_stats(new_stats, dataset.root)
|
||||||
|
dataset.meta.stats = new_stats
|
||||||
|
|
||||||
|
logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def convert_image_to_video_dataset(
|
def convert_image_to_video_dataset(
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
output_dir: Path,
|
output_dir: Path,
|
||||||
|
|||||||
@@ -470,6 +470,13 @@ def make_policy(
|
|||||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||||
if not cfg.input_features:
|
if not cfg.input_features:
|
||||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||||
|
|
||||||
|
# Store action feature names on config for delta_exclude_joints support
|
||||||
|
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
|
||||||
|
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||||
|
if action_names is not None:
|
||||||
|
cfg.action_feature_names = list(action_names)
|
||||||
|
|
||||||
kwargs["config"] = cfg
|
kwargs["config"] = cfg
|
||||||
|
|
||||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||||
|
|||||||
@@ -50,8 +50,12 @@ class PI0Config(PreTrainedConfig):
|
|||||||
min_period: float = 4e-3
|
min_period: float = 4e-3
|
||||||
max_period: float = 4.0
|
max_period: float = 4.0
|
||||||
|
|
||||||
# Delta actions: converts absolute actions to delta (relative to state)
|
# Delta actions: converts absolute actions to delta (relative to state).
|
||||||
use_delta_actions: bool = False
|
use_delta_actions: bool = False
|
||||||
|
# Joint names to exclude from delta conversion (kept as absolute).
|
||||||
|
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||||
|
# Populated at runtime by make_policy from dataset metadata.
|
||||||
|
action_feature_names: list[str] | None = None
|
||||||
|
|
||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|||||||
@@ -1221,6 +1221,18 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
def _build_delta_mask(self, action_dim: int) -> list[bool]:
|
||||||
|
"""Build a boolean mask for delta action conversion.
|
||||||
|
|
||||||
|
Uses action_feature_names and delta_exclude_joints to determine which
|
||||||
|
dims get delta conversion. Falls back to all-True if names are unavailable.
|
||||||
|
"""
|
||||||
|
names = self.config.action_feature_names
|
||||||
|
if names is None:
|
||||||
|
return [True] * action_dim
|
||||||
|
exclude = set(self.config.delta_exclude_joints)
|
||||||
|
return [n not in exclude for n in names]
|
||||||
|
|
||||||
def prepare_action(self, batch):
|
def prepare_action(self, batch):
|
||||||
"""Pad action"""
|
"""Pad action"""
|
||||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||||
@@ -1261,7 +1273,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
actions = actions[:, :, :original_action_dim]
|
actions = actions[:, :, :original_action_dim]
|
||||||
|
|
||||||
if self.config.use_delta_actions:
|
if self.config.use_delta_actions:
|
||||||
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
|
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@@ -1281,7 +1293,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
if self.config.use_delta_actions:
|
if self.config.use_delta_actions:
|
||||||
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
|
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ class PI05Config(PreTrainedConfig):
|
|||||||
|
|
||||||
# Delta actions: converts absolute actions to delta (relative to state).
|
# Delta actions: converts absolute actions to delta (relative to state).
|
||||||
use_delta_actions: bool = False
|
use_delta_actions: bool = False
|
||||||
|
# Joint names to exclude from delta conversion (kept as absolute).
|
||||||
|
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||||
|
# Populated at runtime by make_policy from dataset metadata.
|
||||||
|
action_feature_names: list[str] | None = None
|
||||||
|
|
||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|||||||
@@ -1201,6 +1201,14 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
def _build_delta_mask(self, action_dim: int) -> list[bool]:
|
||||||
|
"""Build a boolean mask for delta action conversion."""
|
||||||
|
names = self.config.action_feature_names
|
||||||
|
if names is None:
|
||||||
|
return [True] * action_dim
|
||||||
|
exclude = set(self.config.delta_exclude_joints)
|
||||||
|
return [n not in exclude for n in names]
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select a single action given environment observations."""
|
"""Select a single action given environment observations."""
|
||||||
@@ -1236,7 +1244,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
if self.config.use_delta_actions:
|
if self.config.use_delta_actions:
|
||||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||||
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
|
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@@ -1257,7 +1265,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
if self.config.use_delta_actions:
|
if self.config.use_delta_actions:
|
||||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||||
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
|
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
|
||||||
|
|
||||||
# Compute loss (no separate state needed for PI05)
|
# Compute loss (no separate state needed for PI05)
|
||||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||||
|
|||||||
@@ -243,6 +243,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||||
|
|
||||||
|
# Recompute action stats as delta if use_delta_actions is enabled
|
||||||
|
if getattr(cfg.policy, "use_delta_actions", False) and is_main_process:
|
||||||
|
logging.info("use_delta_actions is enabled — recomputing action stats as delta (action - state)")
|
||||||
|
from lerobot.datasets.dataset_tools import recompute_stats
|
||||||
|
|
||||||
|
exclude = getattr(cfg.policy, "delta_exclude_joints", ["gripper"])
|
||||||
|
recompute_stats(dataset, skip_image_video=True, delta_action=True, delta_exclude_joints=exclude)
|
||||||
|
|
||||||
# Wait for all processes to finish policy creation before continuing
|
# Wait for all processes to finish policy creation before continuing
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user