From e1ced538e323660de31ed882934e203ec42ad6d9 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Fri, 20 Feb 2026 23:04:20 +0100
Subject: [PATCH] only recompute state for stats

---
 src/lerobot/datasets/dataset_tools.py | 28 +++++++++++++++++++--------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py
index 561448a02..2e405cc42 100644
--- a/src/lerobot/datasets/dataset_tools.py
+++ b/src/lerobot/datasets/dataset_tools.py
@@ -1559,8 +1559,6 @@ def recompute_stats(
         and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
     }
 
-    logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
-
     # Build delta mask if delta_action is enabled
     delta_mask = None
     if delta_action and "action" in features and "observation.state" in features:
@@ -1573,7 +1571,11 @@ def recompute_stats(
         else:
             action_dim = features["action"]["shape"][0]
             delta_mask = [True] * action_dim
-        logging.info(f"Delta action stats enabled (exclude: {delta_exclude_joints})")
+        # Only recompute action stats when delta is enabled — state stays unchanged
+        features_to_compute = {"action": features["action"]}
+        logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
+    else:
+        logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
 
     data_dir = dataset.root / DATA_DIR
     parquet_files = sorted(data_dir.glob("*/*.parquet"))
@@ -1582,6 +1584,8 @@ def recompute_stats(
 
     all_episode_stats = []
     numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
+    # Also need state for delta computation even though we don't recompute state stats
+    needs_state = delta_mask is not None
 
     for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
         df = pd.read_parquet(parquet_path)
@@ -1598,12 +1602,19 @@ def recompute_stats(
                     episode_data[key] = np.array(values)
 
             # Apply delta conversion to actions before computing stats
-            if delta_mask is not None and "action" in episode_data and "observation.state" in episode_data:
+            if delta_mask is not None and "action" in episode_data:
                 from lerobot.processor.delta_action_processor import to_delta_actions
 
-                actions_t = torch.from_numpy(episode_data["action"]).float()
-                states_t = torch.from_numpy(episode_data["observation.state"]).float()
-                episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
+                # Load state for delta even if we're not computing state stats
+                if needs_state and "observation.state" in ep_df.columns:
+                    state_values = ep_df["observation.state"].values
+                    if hasattr(state_values[0], "__len__"):
+                        states = np.stack(state_values)
+                    else:
+                        states = np.array(state_values)
+                    actions_t = torch.from_numpy(episode_data["action"]).float()
+                    states_t = torch.from_numpy(states).float()
+                    episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
 
             ep_stats = compute_episode_stats(episode_data, features_to_compute)
             all_episode_stats.append(ep_stats)
@@ -1614,7 +1625,8 @@ def recompute_stats(
 
     new_stats = aggregate_stats(all_episode_stats)
 
-    if skip_image_video and dataset.meta.stats:
+    # Merge: keep existing stats for features we didn't recompute
+    if dataset.meta.stats:
         for key, value in dataset.meta.stats.items():
             if key not in new_stats:
                 new_stats[key] = value