mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
only recompute action stats when delta action is enabled (state stats stay unchanged)
This commit is contained in:
@@ -1559,8 +1559,6 @@ def recompute_stats(
|
|||||||
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
|
||||||
|
|
||||||
# Build delta mask if delta_action is enabled
|
# Build delta mask if delta_action is enabled
|
||||||
delta_mask = None
|
delta_mask = None
|
||||||
if delta_action and "action" in features and "observation.state" in features:
|
if delta_action and "action" in features and "observation.state" in features:
|
||||||
@@ -1573,7 +1571,11 @@ def recompute_stats(
|
|||||||
else:
|
else:
|
||||||
action_dim = features["action"]["shape"][0]
|
action_dim = features["action"]["shape"][0]
|
||||||
delta_mask = [True] * action_dim
|
delta_mask = [True] * action_dim
|
||||||
logging.info(f"Delta action stats enabled (exclude: {delta_exclude_joints})")
|
# Only recompute action stats when delta is enabled — state stays unchanged
|
||||||
|
features_to_compute = {"action": features["action"]}
|
||||||
|
logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
|
||||||
|
else:
|
||||||
|
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||||
|
|
||||||
data_dir = dataset.root / DATA_DIR
|
data_dir = dataset.root / DATA_DIR
|
||||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||||
@@ -1582,6 +1584,8 @@ def recompute_stats(
|
|||||||
|
|
||||||
all_episode_stats = []
|
all_episode_stats = []
|
||||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||||
|
# Also need state for delta computation even though we don't recompute state stats
|
||||||
|
needs_state = delta_mask is not None
|
||||||
|
|
||||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||||
df = pd.read_parquet(parquet_path)
|
df = pd.read_parquet(parquet_path)
|
||||||
@@ -1598,11 +1602,18 @@ def recompute_stats(
|
|||||||
episode_data[key] = np.array(values)
|
episode_data[key] = np.array(values)
|
||||||
|
|
||||||
# Apply delta conversion to actions before computing stats
|
# Apply delta conversion to actions before computing stats
|
||||||
if delta_mask is not None and "action" in episode_data and "observation.state" in episode_data:
|
if delta_mask is not None and "action" in episode_data:
|
||||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||||
|
|
||||||
|
# Load state for delta even if we're not computing state stats
|
||||||
|
if needs_state and "observation.state" in ep_df.columns:
|
||||||
|
state_values = ep_df["observation.state"].values
|
||||||
|
if hasattr(state_values[0], "__len__"):
|
||||||
|
states = np.stack(state_values)
|
||||||
|
else:
|
||||||
|
states = np.array(state_values)
|
||||||
actions_t = torch.from_numpy(episode_data["action"]).float()
|
actions_t = torch.from_numpy(episode_data["action"]).float()
|
||||||
states_t = torch.from_numpy(episode_data["observation.state"]).float()
|
states_t = torch.from_numpy(states).float()
|
||||||
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
|
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
|
||||||
|
|
||||||
ep_stats = compute_episode_stats(episode_data, features_to_compute)
|
ep_stats = compute_episode_stats(episode_data, features_to_compute)
|
||||||
@@ -1614,7 +1625,8 @@ def recompute_stats(
|
|||||||
|
|
||||||
new_stats = aggregate_stats(all_episode_stats)
|
new_stats = aggregate_stats(all_episode_stats)
|
||||||
|
|
||||||
if skip_image_video and dataset.meta.stats:
|
# Merge: keep existing stats for features we didn't recompute
|
||||||
|
if dataset.meta.stats:
|
||||||
for key, value in dataset.meta.stats.items():
|
for key, value in dataset.meta.stats.items():
|
||||||
if key not in new_stats:
|
if key not in new_stats:
|
||||||
new_stats[key] = value
|
new_stats[key] = value
|
||||||
|
|||||||
Reference in New Issue
Block a user