fix relative actions: convert before normalization, use global stats

The previous implementation had a double-normalization bug: the
preprocessor normalized actions with absolute stats, then
convert_to_relative subtracted normalized state (wrong), then the
per-timestep normalizer re-normalized.

Now the correct flow is:
1. Convert batch to relative on raw data (before preprocessing)
2. Compute global relative stats (mean/std across all timesteps)
3. Hotswap the preprocessor normalizer to use relative stats
4. Preprocessor normalizes relative values correctly

This brings loss from ~3000+ down to ~0.5, matching the main branch.

Made-with: Cursor
This commit is contained in:
pepijn
2026-04-01 20:45:35 +00:00
parent 900f1a42e9
commit 936187cd76
2 changed files with 88 additions and 31 deletions
+45 -30
View File
@@ -49,8 +49,7 @@ from lerobot.utils.train_utils import (
)
from lerobot.utils.relative_actions import (
convert_to_relative_actions,
compute_relative_action_stats,
PerTimestepNormalizer,
compute_global_relative_stats,
)
from lerobot.utils.utils import (
format_big_number,
@@ -304,42 +303,60 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
device=device,
)
# Compute per-timestep normalizer for relative actions
relative_normalizer = None
# Compute relative action/state stats and hotswap them into the normalizer
raw_state_key = None
if cfg.use_relative_actions:
from lerobot.processor.normalize_processor import hotswap_stats
mode = "actions + state" if cfg.use_relative_state else "actions only"
cfg.output_dir.mkdir(parents=True, exist_ok=True)
stats_path = cfg.output_dir / "relative_stats.pt"
reverse_rename = {v: k for k, v in cfg.rename_map.items()} if cfg.rename_map else {}
raw_state_key = reverse_rename.get("observation.state", "observation.state")
if is_main_process:
logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"]))
if stats_path.exists():
logging.info(f"Loading pre-computed stats from: {stats_path}")
logging.info(f"Loading pre-computed relative stats from: {stats_path}")
else:
logging.info("Computing per-timestep stats (first 1000 batches)...")
logging.info("Using fresh dataset to avoid video decoder state issues...")
# Create separate dataset instance to avoid corrupting main dataset's video decoders
logging.info("Computing global relative stats (first 1000 batches)...")
stats_dataset = make_dataset(cfg)
temp_loader = torch.utils.data.DataLoader(
stats_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
)
reverse_rename = {v: k for k, v in cfg.rename_map.items()} if cfg.rename_map else {}
raw_state_key = reverse_rename.get("observation.state", "observation.state")
mean, std = compute_relative_action_stats(temp_loader, state_key=raw_state_key, num_batches=1000)
rel_stats = compute_global_relative_stats(
temp_loader, state_key=raw_state_key,
convert_state=cfg.use_relative_state, num_batches=1000,
)
del temp_loader, stats_dataset
gc.collect()
torch.save({"mean": mean, "std": std}, stats_path)
logging.info(f"Saved stats to: {stats_path}")
# Poll for stats file instead of using NCCL barrier (avoids timeout during long computation)
torch.save(rel_stats, stats_path)
logging.info(f"Saved relative stats to: {stats_path}")
if not is_main_process:
while not stats_path.exists():
time.sleep(5)
data = torch.load(stats_path, weights_only=True, map_location="cpu")
relative_normalizer = PerTimestepNormalizer(data["mean"], data["std"])
accelerator.wait_for_everyone() # Sync after everyone has loaded
rel_stats = torch.load(stats_path, weights_only=True, map_location="cpu")
# Replace absolute stats with relative stats in the normalizer
updated_stats = dict(dataset.meta.stats)
updated_stats["action"] = {
**updated_stats["action"],
"mean": rel_stats["action_mean"].numpy(),
"std": rel_stats["action_std"].numpy(),
}
if cfg.use_relative_state and "state_mean" in rel_stats:
updated_stats[raw_state_key] = {
**updated_stats.get(raw_state_key, {}),
"mean": rel_stats["state_mean"].numpy(),
"std": rel_stats["state_std"].numpy(),
}
preprocessor = hotswap_stats(preprocessor, updated_stats)
logging.info("Hotswapped normalizer stats with relative stats")
accelerator.wait_for_everyone()
step = 0 # number of policy updates (forward + backward + optim)
@@ -427,13 +444,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
batch = preprocessor(batch)
# Convert to relative actions (and optionally state) if enabled
# Convert to relative on raw data BEFORE normalization
if cfg.use_relative_actions:
batch = convert_to_relative_actions(batch, convert_state=cfg.use_relative_state)
if relative_normalizer is not None:
batch["action"] = relative_normalizer.normalize(batch["action"])
batch = convert_to_relative_actions(
batch, state_key=raw_state_key, convert_state=cfg.use_relative_state,
)
batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time
@@ -489,9 +507,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
preprocessor=preprocessor,
postprocessor=postprocessor,
)
# Save relative action stats with checkpoint
if relative_normalizer is not None:
relative_normalizer.save(checkpoint_dir / "relative_stats.pt")
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
+43 -1
View File
@@ -74,11 +74,53 @@ def compute_relative_action_stats(
rel = action.clone()
rel[..., :min_dim] -= current_pos[:, None, :min_dim]
all_rel.append(rel)
all_rel = torch.cat(all_rel, dim=0)
return all_rel.mean(dim=0), all_rel.std(dim=0).clamp(min=1e-6)
def compute_global_relative_stats(
dataloader,
state_key: str = "observation.state",
convert_state: bool = True,
num_batches: int | None = None,
) -> dict[str, torch.Tensor]:
"""Compute global mean/std for relative actions (and state) across all timesteps.
Returns stats compatible with the standard MEAN_STD normalizer (shape = action_dim).
"""
all_rel_actions = []
all_rel_states = []
for i, batch in enumerate(dataloader):
if num_batches is not None and i >= num_batches:
break
action, state = batch["action"], batch[state_key]
current_pos = state[:, -1, :] if state.dim() == 3 else state
min_dim = min(action.shape[-1], current_pos.shape[-1])
rel = action.clone()
rel[..., :min_dim] -= current_pos[:, None, :min_dim]
all_rel_actions.append(rel.reshape(-1, rel.shape[-1]))
if convert_state:
if state.dim() == 3:
rel_state = state - current_pos[:, None, :]
else:
rel_state = torch.zeros_like(state)
all_rel_states.append(rel_state.reshape(-1, rel_state.shape[-1]))
all_rel_actions = torch.cat(all_rel_actions, dim=0)
result = {
"action_mean": all_rel_actions.mean(dim=0),
"action_std": all_rel_actions.std(dim=0).clamp(min=1e-6),
}
if convert_state and all_rel_states:
all_rel_states = torch.cat(all_rel_states, dim=0)
result["state_mean"] = all_rel_states.mean(dim=0)
result["state_std"] = all_rel_states.std(dim=0).clamp(min=1e-6)
return result
def convert_to_relative(
batch: dict,
state_key: str = "observation.state",