mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
only recompute state for stats
This commit is contained in:
@@ -1559,8 +1559,6 @@ def recompute_stats(
|
||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||
}
|
||||
|
||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||
|
||||
# Build delta mask if delta_action is enabled
|
||||
delta_mask = None
|
||||
if delta_action and "action" in features and "observation.state" in features:
|
||||
@@ -1573,7 +1571,11 @@ def recompute_stats(
|
||||
else:
|
||||
action_dim = features["action"]["shape"][0]
|
||||
delta_mask = [True] * action_dim
|
||||
logging.info(f"Delta action stats enabled (exclude: {delta_exclude_joints})")
|
||||
# Only recompute action stats when delta is enabled — state stays unchanged
|
||||
features_to_compute = {"action": features["action"]}
|
||||
logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
|
||||
else:
|
||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
@@ -1582,6 +1584,8 @@ def recompute_stats(
|
||||
|
||||
all_episode_stats = []
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
# Also need state for delta computation even though we don't recompute state stats
|
||||
needs_state = delta_mask is not None
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
@@ -1598,12 +1602,19 @@ def recompute_stats(
|
||||
episode_data[key] = np.array(values)
|
||||
|
||||
# Apply delta conversion to actions before computing stats
|
||||
if delta_mask is not None and "action" in episode_data and "observation.state" in episode_data:
|
||||
if delta_mask is not None and "action" in episode_data:
|
||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||
|
||||
actions_t = torch.from_numpy(episode_data["action"]).float()
|
||||
states_t = torch.from_numpy(episode_data["observation.state"]).float()
|
||||
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
|
||||
# Load state for delta even if we're not computing state stats
|
||||
if needs_state and "observation.state" in ep_df.columns:
|
||||
state_values = ep_df["observation.state"].values
|
||||
if hasattr(state_values[0], "__len__"):
|
||||
states = np.stack(state_values)
|
||||
else:
|
||||
states = np.array(state_values)
|
||||
actions_t = torch.from_numpy(episode_data["action"]).float()
|
||||
states_t = torch.from_numpy(states).float()
|
||||
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
|
||||
|
||||
ep_stats = compute_episode_stats(episode_data, features_to_compute)
|
||||
all_episode_stats.append(ep_stats)
|
||||
@@ -1614,7 +1625,8 @@ def recompute_stats(
|
||||
|
||||
new_stats = aggregate_stats(all_episode_stats)
|
||||
|
||||
if skip_image_video and dataset.meta.stats:
|
||||
# Merge: keep existing stats for features we didn't recompute
|
||||
if dataset.meta.stats:
|
||||
for key, value in dataset.meta.stats.items():
|
||||
if key not in new_stats:
|
||||
new_stats[key] = value
|
||||
|
||||
Reference in New Issue
Block a user