mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Merge branch 'feat/add_relative_action_pi_models' into feat/mirror
This commit is contained in:
@@ -243,13 +243,35 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
# Recompute action stats as delta if use_delta_actions is enabled
|
||||
# Recompute action stats as delta if use_delta_actions is enabled.
|
||||
# Must iterate the actual dataset (which returns action chunks via delta_timestamps)
|
||||
# so stats capture the full range of chunk-level deltas, not just per-frame deltas.
|
||||
if getattr(cfg.policy, "use_delta_actions", False) and is_main_process:
|
||||
logging.info("use_delta_actions is enabled — recomputing action stats as delta (action - state)")
|
||||
from lerobot.datasets.dataset_tools import recompute_stats
|
||||
logging.info("use_delta_actions is enabled — computing delta action stats from dataset chunks")
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||
|
||||
exclude = getattr(cfg.policy, "delta_exclude_joints", [])
|
||||
recompute_stats(dataset, skip_image_video=True, delta_action=True, delta_exclude_joints=exclude)
|
||||
all_delta_actions = []
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
action = item["action"]
|
||||
state = item["observation.state"]
|
||||
# action may be (chunk_size, action_dim) or (action_dim,)
|
||||
if action.ndim == 1:
|
||||
action = action.unsqueeze(0)
|
||||
mask = [True] * action.shape[-1]
|
||||
delta = to_delta_actions(action.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_delta_actions.append(delta.numpy())
|
||||
|
||||
import numpy as np
|
||||
|
||||
all_delta = np.concatenate(all_delta_actions, axis=0)
|
||||
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
|
||||
dataset.meta.stats["action"] = delta_stats
|
||||
logging.info(
|
||||
f"Delta action stats computed from {len(dataset)} samples: "
|
||||
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}"
|
||||
)
|
||||
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
Reference in New Issue
Block a user