Merge branch 'feat/add_relative_action_pi_models' into feat/mirror

This commit is contained in:
Pepijn
2026-02-21 17:12:52 +01:00
2 changed files with 17 additions and 11 deletions
+3 -9
View File
@@ -331,11 +331,9 @@ class _NormalizationMixin:
)
mean, std = stats["mean"], stats["std"]
# Avoid division by zero by adding a small epsilon.
denom = std + self.eps
if inverse:
return tensor * std + mean
return (tensor - mean) / denom
return tensor * (std + 1e-6) + mean
return (tensor - mean) / (std + 1e-6)
if norm_mode == NormalizationMode.MIN_MAX:
min_val = stats.get("min", None)
@@ -367,11 +365,7 @@ class _NormalizationMixin:
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
)
denom = q99 - q01
# Avoid division by zero by adding epsilon when quantiles are identical
denom = torch.where(
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
)
denom = q99 - q01 + 1e-6
if inverse:
return (tensor + 1.0) * denom / 2.0 + q01
return 2.0 * (tensor - q01) / denom - 1.0
+14 -2
View File
@@ -284,10 +284,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
all_delta = np.concatenate(all_delta_actions, axis=0)
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
dataset.meta.stats["action"] = delta_stats
# Determine normalization type for logging
norm_type = "UNKNOWN"
if hasattr(cfg.policy, "normalization_mapping"):
from lerobot.configs.types import NormalizationMode
action_norm = cfg.policy.normalization_mapping.get("ACTION", None)
norm_type = action_norm.value if action_norm else "UNKNOWN"
logging.info(
f"Delta action stats: mean={np.abs(delta_stats['mean']).mean():.4f}, "
f"std={delta_stats['std'].mean():.4f}"
f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): "
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, "
f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}"
)
if norm_type == "QUANTILES":
q_range = (delta_stats['q99'] - delta_stats['q01']).mean()
logging.info(f" Quantile range (q99-q01): {q_range:.4f}")
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()