diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py
index 123d455c6..561448a02 100644
--- a/src/lerobot/datasets/dataset_tools.py
+++ b/src/lerobot/datasets/dataset_tools.py
@@ -37,7 +37,7 @@ import torch
 from tqdm import tqdm
 
 from lerobot.datasets.aggregate import aggregate_datasets
-from lerobot.datasets.compute_stats import aggregate_stats
+from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.datasets.utils import (
     DATA_DIR,
@@ -1522,6 +1522,110 @@ def modify_tasks(
     return dataset
 
 
+def recompute_stats(
+    dataset: LeRobotDataset,
+    skip_image_video: bool = True,
+    delta_action: bool = False,
+    delta_exclude_joints: list[str] | None = None,
+) -> LeRobotDataset:
+    """Recompute stats.json from scratch by iterating all episodes.
+
+    Args:
+        dataset: The LeRobotDataset to recompute stats for.
+        skip_image_video: If True (default), only recompute stats for numeric features
+            (action, state, etc.) and keep existing image/video stats unchanged.
+        delta_action: If True, compute action stats as delta (action - state).
+            Useful when training with use_delta_actions=True so normalization matches.
+        delta_exclude_joints: Joint names to exclude from delta conversion when
+            delta_action=True. These dims keep absolute stats. Uses dataset's
+            action feature names to build the mask. Default: ["gripper"].
+
+    Returns:
+        The same dataset with updated stats.
+    """
+    features = dataset.meta.features
+    numeric_features = {
+        k: v for k, v in features.items()
+        if v["dtype"] not in ["image", "video", "string"]
+        and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
+    }
+
+    if skip_image_video:
+        features_to_compute = numeric_features
+    else:
+        features_to_compute = {
+            k: v for k, v in features.items()
+            if v["dtype"] != "string"
+            and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
+        }
+
+    logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
+
+    # Build delta mask if delta_action is enabled
+    delta_mask = None
+    if delta_action and "action" in features and "observation.state" in features:
+        if delta_exclude_joints is None:
+            delta_exclude_joints = ["gripper"]
+        action_names = features["action"].get("names")
+        if action_names is not None:
+            exclude = set(delta_exclude_joints)
+            delta_mask = [n not in exclude for n in action_names]
+        else:
+            action_dim = features["action"]["shape"][0]
+            delta_mask = [True] * action_dim
+        logging.info(f"Delta action stats enabled (exclude: {delta_exclude_joints})")
+
+    data_dir = dataset.root / DATA_DIR
+    parquet_files = sorted(data_dir.glob("*/*.parquet"))
+    if not parquet_files:
+        raise ValueError(f"No parquet files found in {data_dir}")
+
+    all_episode_stats = []
+    numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
+
+    for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
+        df = pd.read_parquet(parquet_path)
+
+        for ep_idx in sorted(df["episode_index"].unique()):
+            ep_df = df[df["episode_index"] == ep_idx]
+            episode_data = {}
+            for key in numeric_keys:
+                if key in ep_df.columns:
+                    values = ep_df[key].values
+                    if hasattr(values[0], "__len__"):
+                        episode_data[key] = np.stack(values)
+                    else:
+                        episode_data[key] = np.array(values)
+
+            # Apply delta conversion to actions before computing stats
+            if delta_mask is not None and "action" in episode_data and "observation.state" in episode_data:
+                from lerobot.processor.delta_action_processor import to_delta_actions
+
+                actions_t = torch.from_numpy(episode_data["action"]).float()
+                states_t = torch.from_numpy(episode_data["observation.state"]).float()
+                episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
+
+            ep_stats = compute_episode_stats(episode_data, features_to_compute)
+            all_episode_stats.append(ep_stats)
+
+    if not all_episode_stats:
+        logging.warning("No episode stats computed")
+        return dataset
+
+    new_stats = aggregate_stats(all_episode_stats)
+
+    if skip_image_video and dataset.meta.stats:
+        for key, value in dataset.meta.stats.items():
+            if key not in new_stats:
+                new_stats[key] = value
+
+    write_stats(new_stats, dataset.root)
+    dataset.meta.stats = new_stats
+
+    logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
+    return dataset
+
+
 def convert_image_to_video_dataset(
     dataset: LeRobotDataset,
     output_dir: Path,
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index a593e5bcb..73588e11f 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -470,6 +470,13 @@ def make_policy(
         cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
         if not cfg.input_features:
             cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
+
+    # Store action feature names on config for delta_exclude_joints support
+    if ds_meta is not None and hasattr(cfg, "action_feature_names"):
+        action_names = ds_meta.features.get(ACTION, {}).get("names")
+        if action_names is not None:
+            cfg.action_feature_names = list(action_names)
+
     kwargs["config"] = cfg
 
     # Pass dataset_stats to the policy if available (needed for some policies like SARM)
diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py
index 723f0b115..0e10fe196 100644
--- a/src/lerobot/policies/pi0/configuration_pi0.py
+++ b/src/lerobot/policies/pi0/configuration_pi0.py
@@ -50,8 +50,12 @@ class PI0Config(PreTrainedConfig):
     min_period: float = 4e-3
     max_period: float = 4.0
 
-    # Delta actions: converts absolute actions to delta (relative to state)
+    # Delta actions: converts absolute actions to delta (relative to state).
     use_delta_actions: bool = False
+    # Joint names to exclude from delta conversion (kept as absolute).
+    delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
+    # Populated at runtime by make_policy from dataset metadata.
+    action_feature_names: list[str] | None = None
 
     # Real-Time Chunking (RTC) configuration
     rtc_config: RTCConfig | None = None
diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py
index 5f82591e2..c19c788f5 100644
--- a/src/lerobot/policies/pi0/modeling_pi0.py
+++ b/src/lerobot/policies/pi0/modeling_pi0.py
@@ -1221,6 +1221,18 @@ class PI0Policy(PreTrainedPolicy):
         state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
         return state
 
+    def _build_delta_mask(self, action_dim: int) -> list[bool]:
+        """Build a boolean mask for delta action conversion.
+
+        Uses action_feature_names and delta_exclude_joints to determine which
+        dims get delta conversion. Falls back to all-True if names are unavailable.
+        """
+        names = self.config.action_feature_names
+        if names is None:
+            return [True] * action_dim
+        exclude = set(self.config.delta_exclude_joints)
+        return [n not in exclude for n in names]
+
     def prepare_action(self, batch):
         """Pad action"""
         actions = pad_vector(batch[ACTION], self.config.max_action_dim)
@@ -1261,7 +1273,7 @@ class PI0Policy(PreTrainedPolicy):
             actions = actions[:, :, :original_action_dim]
 
         if self.config.use_delta_actions:
-            actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
+            actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
 
         return actions
 
@@ -1281,7 +1293,7 @@ class PI0Policy(PreTrainedPolicy):
         actions = self.prepare_action(batch)
 
         if self.config.use_delta_actions:
-            actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
+            actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
 
         # Compute loss
         losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py
index da0e29137..5752a54b0 100644
--- a/src/lerobot/policies/pi05/configuration_pi05.py
+++ b/src/lerobot/policies/pi05/configuration_pi05.py
@@ -52,6 +52,10 @@
     # Delta actions: converts absolute actions to delta (relative to state).
     use_delta_actions: bool = False
+    # Joint names to exclude from delta conversion (kept as absolute).
+    delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
+    # Populated at runtime by make_policy from dataset metadata.
+    action_feature_names: list[str] | None = None
 
     # Real-Time Chunking (RTC) configuration
     rtc_config: RTCConfig | None = None
diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
index fca4405fa..d70228275 100644
--- a/src/lerobot/policies/pi05/modeling_pi05.py
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -1201,6 +1201,14 @@ class PI05Policy(PreTrainedPolicy):
         actions = pad_vector(batch[ACTION], self.config.max_action_dim)
         return actions
 
+    def _build_delta_mask(self, action_dim: int) -> list[bool]:
+        """Build a boolean mask for delta action conversion."""
+        names = self.config.action_feature_names
+        if names is None:
+            return [True] * action_dim
+        exclude = set(self.config.delta_exclude_joints)
+        return [n not in exclude for n in names]
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations."""
@@ -1236,7 +1244,7 @@ class PI05Policy(PreTrainedPolicy):
 
         if self.config.use_delta_actions:
             state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
-            actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
+            actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
 
         return actions
 
@@ -1257,7 +1265,7 @@ class PI05Policy(PreTrainedPolicy):
 
         if self.config.use_delta_actions:
             state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
-            actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
+            actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
 
         # Compute loss (no separate state needed for PI05)
         losses = self.model.forward(images, img_masks, tokens, masks, actions)
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 93b99e245..8bfbf224f 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -243,6 +243,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         peft_cli_overrides = dataclasses.asdict(cfg.peft)
         policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
 
+    # Recompute action stats as delta if use_delta_actions is enabled
+    if getattr(cfg.policy, "use_delta_actions", False) and is_main_process:
+        logging.info("use_delta_actions is enabled — recomputing action stats as delta (action - state)")
+        from lerobot.datasets.dataset_tools import recompute_stats
+
+        exclude = getattr(cfg.policy, "delta_exclude_joints", ["gripper"])
+        recompute_stats(dataset, skip_image_video=True, delta_action=True, delta_exclude_joints=exclude)
+
     # Wait for all processes to finish policy creation before continuing
     accelerator.wait_for_everyone()