mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 04:37:01 +00:00
fix relative actions: convert before normalization, use global stats
The previous implementation had a double-normalization bug: the preprocessor normalized actions with absolute stats, then convert_to_relative subtracted the normalized state (wrong), and then the per-timestep normalizer re-normalized the result.

Now the correct flow is:
1. Convert the batch to relative on raw data (before preprocessing).
2. Compute global relative stats (mean/std across all timesteps).
3. Hot-swap the preprocessor normalizer to use the relative stats.
4. The preprocessor then normalizes the relative values correctly.

This brings the loss from ~3000+ down to ~0.5, matching the main branch.

Made-with: Cursor
This commit is contained in:
@@ -49,8 +49,7 @@ from lerobot.utils.train_utils import (
|
||||
)
|
||||
from lerobot.utils.relative_actions import (
|
||||
convert_to_relative_actions,
|
||||
compute_relative_action_stats,
|
||||
PerTimestepNormalizer,
|
||||
compute_global_relative_stats,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
@@ -304,42 +303,60 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Compute per-timestep normalizer for relative actions
|
||||
relative_normalizer = None
|
||||
# Compute relative action/state stats and hotswap them into the normalizer
|
||||
raw_state_key = None
|
||||
if cfg.use_relative_actions:
|
||||
from lerobot.processor.normalize_processor import hotswap_stats
|
||||
|
||||
mode = "actions + state" if cfg.use_relative_state else "actions only"
|
||||
cfg.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
stats_path = cfg.output_dir / "relative_stats.pt"
|
||||
|
||||
|
||||
reverse_rename = {v: k for k, v in cfg.rename_map.items()} if cfg.rename_map else {}
|
||||
raw_state_key = reverse_rename.get("observation.state", "observation.state")
|
||||
|
||||
if is_main_process:
|
||||
logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"]))
|
||||
|
||||
|
||||
if stats_path.exists():
|
||||
logging.info(f"Loading pre-computed stats from: {stats_path}")
|
||||
logging.info(f"Loading pre-computed relative stats from: {stats_path}")
|
||||
else:
|
||||
logging.info("Computing per-timestep stats (first 1000 batches)...")
|
||||
logging.info("Using fresh dataset to avoid video decoder state issues...")
|
||||
# Create separate dataset instance to avoid corrupting main dataset's video decoders
|
||||
logging.info("Computing global relative stats (first 1000 batches)...")
|
||||
stats_dataset = make_dataset(cfg)
|
||||
temp_loader = torch.utils.data.DataLoader(
|
||||
stats_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
|
||||
)
|
||||
reverse_rename = {v: k for k, v in cfg.rename_map.items()} if cfg.rename_map else {}
|
||||
raw_state_key = reverse_rename.get("observation.state", "observation.state")
|
||||
mean, std = compute_relative_action_stats(temp_loader, state_key=raw_state_key, num_batches=1000)
|
||||
rel_stats = compute_global_relative_stats(
|
||||
temp_loader, state_key=raw_state_key,
|
||||
convert_state=cfg.use_relative_state, num_batches=1000,
|
||||
)
|
||||
del temp_loader, stats_dataset
|
||||
gc.collect()
|
||||
torch.save({"mean": mean, "std": std}, stats_path)
|
||||
logging.info(f"Saved stats to: {stats_path}")
|
||||
|
||||
# Poll for stats file instead of using NCCL barrier (avoids timeout during long computation)
|
||||
torch.save(rel_stats, stats_path)
|
||||
logging.info(f"Saved relative stats to: {stats_path}")
|
||||
|
||||
if not is_main_process:
|
||||
while not stats_path.exists():
|
||||
time.sleep(5)
|
||||
|
||||
data = torch.load(stats_path, weights_only=True, map_location="cpu")
|
||||
relative_normalizer = PerTimestepNormalizer(data["mean"], data["std"])
|
||||
accelerator.wait_for_everyone() # Sync after everyone has loaded
|
||||
|
||||
rel_stats = torch.load(stats_path, weights_only=True, map_location="cpu")
|
||||
|
||||
# Replace absolute stats with relative stats in the normalizer
|
||||
updated_stats = dict(dataset.meta.stats)
|
||||
updated_stats["action"] = {
|
||||
**updated_stats["action"],
|
||||
"mean": rel_stats["action_mean"].numpy(),
|
||||
"std": rel_stats["action_std"].numpy(),
|
||||
}
|
||||
if cfg.use_relative_state and "state_mean" in rel_stats:
|
||||
updated_stats[raw_state_key] = {
|
||||
**updated_stats.get(raw_state_key, {}),
|
||||
"mean": rel_stats["state_mean"].numpy(),
|
||||
"std": rel_stats["state_std"].numpy(),
|
||||
}
|
||||
preprocessor = hotswap_stats(preprocessor, updated_stats)
|
||||
logging.info("Hotswapped normalizer stats with relative stats")
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
@@ -427,13 +444,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Convert to relative actions (and optionally state) if enabled
|
||||
|
||||
# Convert to relative on raw data BEFORE normalization
|
||||
if cfg.use_relative_actions:
|
||||
batch = convert_to_relative_actions(batch, convert_state=cfg.use_relative_state)
|
||||
if relative_normalizer is not None:
|
||||
batch["action"] = relative_normalizer.normalize(batch["action"])
|
||||
batch = convert_to_relative_actions(
|
||||
batch, state_key=raw_state_key, convert_state=cfg.use_relative_state,
|
||||
)
|
||||
|
||||
batch = preprocessor(batch)
|
||||
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
@@ -489,9 +507,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
# Save relative action stats with checkpoint
|
||||
if relative_normalizer is not None:
|
||||
relative_normalizer.save(checkpoint_dir / "relative_stats.pt")
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
@@ -74,11 +74,53 @@ def compute_relative_action_stats(
|
||||
rel = action.clone()
|
||||
rel[..., :min_dim] -= current_pos[:, None, :min_dim]
|
||||
all_rel.append(rel)
|
||||
|
||||
|
||||
all_rel = torch.cat(all_rel, dim=0)
|
||||
return all_rel.mean(dim=0), all_rel.std(dim=0).clamp(min=1e-6)
|
||||
|
||||
|
||||
def compute_global_relative_stats(
    dataloader,
    state_key: str = "observation.state",
    convert_state: bool = True,
    num_batches: int | None = None,
) -> dict[str, torch.Tensor]:
    """Compute global mean/std for relative actions (and state) across all timesteps.

    Every batch is converted to relative values by subtracting the current
    position (the last entry along dim 1 of the state when the state is 3-D,
    otherwise the state itself) from the action. Only the first
    ``min(action_dim, state_dim)`` features are shifted; any remaining action
    features are kept as-is. All leading dimensions are then flattened so the
    statistics are pooled over every row rather than kept per timestep.

    Returns stats compatible with the standard MEAN_STD normalizer (shape = action_dim).

    Args:
        dataloader: Iterable of batches (mappings) holding an ``"action"``
            tensor and a tensor under ``state_key``. The broadcasting below
            requires a 3-D action, presumably (batch, horizon, action_dim) --
            TODO confirm against the dataset.
        state_key: Key of the state tensor inside each batch.
        convert_state: When true, also compute stats for the relative state.
            A 2-D state is relative to itself, so it contributes all zeros.
        num_batches: Maximum number of batches to consume; ``None`` consumes
            the whole dataloader.

    Returns:
        A dict with ``"action_mean"`` and ``"action_std"``, plus
        ``"state_mean"`` and ``"state_std"`` when ``convert_state`` is true.
        Every std is clamped to at least ``1e-6``.

    Raises:
        ValueError: If no batch was consumed (empty dataloader or
            ``num_batches=0``), since no statistics can be computed.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import islice

    all_rel_actions = []
    all_rel_states = []
    # islice stops *before* fetching batch number num_batches + 1, unlike an
    # enumerate/break loop which loads (and possibly decodes) one batch too many.
    for batch in islice(dataloader, num_batches):
        action, state = batch["action"], batch[state_key]
        current_pos = state[:, -1, :] if state.dim() == 3 else state

        # Action and state may have different widths; only the shared leading
        # features are made relative.
        min_dim = min(action.shape[-1], current_pos.shape[-1])
        rel = action.clone()  # clone so the caller's batch is not mutated
        rel[..., :min_dim] -= current_pos[:, None, :min_dim]
        all_rel_actions.append(rel.reshape(-1, rel.shape[-1]))

        if convert_state:
            if state.dim() == 3:
                rel_state = state - current_pos[:, None, :]
            else:
                rel_state = torch.zeros_like(state)
            all_rel_states.append(rel_state.reshape(-1, rel_state.shape[-1]))

    if not all_rel_actions:
        raise ValueError(
            "compute_global_relative_stats received no batches; "
            "cannot compute relative statistics from an empty dataloader."
        )

    action_mean, action_std = _mean_and_std(torch.cat(all_rel_actions, dim=0))
    result = {"action_mean": action_mean, "action_std": action_std}
    if all_rel_states:  # populated only when convert_state is true
        state_mean, state_std = _mean_and_std(torch.cat(all_rel_states, dim=0))
        result["state_mean"] = state_mean
        result["state_std"] = state_std
    return result


def _mean_and_std(samples: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the per-feature mean and clamped std of ``samples`` (rows = samples)."""
    mean = samples.mean(dim=0)
    if samples.shape[0] > 1:
        std = samples.std(dim=0)
    else:
        # The unbiased std of a single row is NaN, and clamp() does not remove
        # NaN; treat the spread as zero so the clamp floor applies instead.
        std = torch.zeros_like(mean)
    return mean, std.clamp(min=1e-6)
|
||||
|
||||
|
||||
def convert_to_relative(
|
||||
batch: dict,
|
||||
state_key: str = "observation.state",
|
||||
|
||||
Reference in New Issue
Block a user