only recompute action stats when delta is enabled (load state solely for delta conversion)

This commit is contained in:
Pepijn
2026-02-20 23:04:20 +01:00
parent a2f5b3571e
commit e1ced538e3
+20 -8
View File
@@ -1559,8 +1559,6 @@ def recompute_stats(
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
}
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
# Build delta mask if delta_action is enabled
delta_mask = None
if delta_action and "action" in features and "observation.state" in features:
@@ -1573,7 +1571,11 @@ def recompute_stats(
else:
action_dim = features["action"]["shape"][0]
delta_mask = [True] * action_dim
logging.info(f"Delta action stats enabled (exclude: {delta_exclude_joints})")
# Only recompute action stats when delta is enabled — state stays unchanged
features_to_compute = {"action": features["action"]}
logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
else:
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
@@ -1582,6 +1584,8 @@ def recompute_stats(
all_episode_stats = []
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
# Also need state for delta computation even though we don't recompute state stats
needs_state = delta_mask is not None
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
df = pd.read_parquet(parquet_path)
@@ -1598,12 +1602,19 @@ def recompute_stats(
episode_data[key] = np.array(values)
# Apply delta conversion to actions before computing stats
if delta_mask is not None and "action" in episode_data and "observation.state" in episode_data:
if delta_mask is not None and "action" in episode_data:
from lerobot.processor.delta_action_processor import to_delta_actions
actions_t = torch.from_numpy(episode_data["action"]).float()
states_t = torch.from_numpy(episode_data["observation.state"]).float()
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
# Load state for delta even if we're not computing state stats
if needs_state and "observation.state" in ep_df.columns:
state_values = ep_df["observation.state"].values
if hasattr(state_values[0], "__len__"):
states = np.stack(state_values)
else:
states = np.array(state_values)
actions_t = torch.from_numpy(episode_data["action"]).float()
states_t = torch.from_numpy(states).float()
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
ep_stats = compute_episode_stats(episode_data, features_to_compute)
all_episode_stats.append(ep_stats)
@@ -1614,7 +1625,8 @@ def recompute_stats(
new_stats = aggregate_stats(all_episode_stats)
if skip_image_video and dataset.meta.stats:
# Merge: keep existing stats for features we didn't recompute
if dataset.meta.stats:
for key, value in dataset.meta.stats.items():
if key not in new_stats:
new_stats[key] = value