mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-29 15:39:56 +00:00
Add sarm (#2639)
* add initial modeling * make rewind pretrained policy * add annotation * small fix * add sarm * subtasks * fix spawn * fix rewind discrepancies * Add script to generate embedding for dataset (#2138) * Add generate and validate script * fix precommit * Improve generate embeddings function by using dataset tools (#2206) --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> * cleanup * change order train log * print batch size * update sarm processor * add reward output * change expected features * add image validation * change validation * get state input from dataset stats * raise if no state key is found * pass stats * cleanup and refactor * add episode inddex to complementary data * add subtask init and detection * revert lerobot_train changes * pass dataset metadata to policy * change loadig subtasks * add small logging * fix progress conversion and adding initial frame * use large offset for initial frame (ugly) * Remove rewind, use clip tokenizer * add tests, implement formula 1,2 correctly and cleanup * use task from dataset, cleanup visualizer * simplify * simplify and cleanup code and move compute_temporal_proportions to utils * fix normalization in visualization * Fix visualization and change prompt * fix formatting * add visualize subtask annotations * use qwen thinking * try different prompt * format * update prompt * higher temp, long output * different settings * use instruct * show full resp * split message * Temp: increase tolerance dataset * Fix RA-BC (#2572) * Add next observation loading for RA-BC progress deltas * Compute weights based on temporal progress deltas instead of static rewards * Add hard-masking for negative progress deltas in weight computation * Feat/add dual head (#2582) * Add dual dense sparse head and annotation * Add docs * add dual to procesor * cleanup * change sampling in visualize and cleanup * remove validation * remove compile * Feat/test uniform (#2587) * test uniform * add different string for misaligned * Fix rewind and add tests * uncomment text implementation * run precommit * Add head mode for ra-bc * fix visalization of single task * add * return per sample loss * Fix RA_BC (#2602) * update rabc implementation * compute rabc beforehand * fix import * add only progress calulation * use precomputed progress * multi gpu processing * import * fix dataset meta data extraction * add logging * logging * log * progress per episode * split differently * move clip to gpu * pre decode frames for an episode * fix cuda initalization * fix import * multi processing * rename * fix import * fix * fix rabc * use last known progress if oob * use last known progress if oob * add misalignment loss with random embeddings * discard previous changes * add selection of models to docs for ra_bc * add transformers dep * extend tolerance * initial commit with new codebase * add tests * fix * remove temporal sampler * drop last frame for sampler * use original ref * some fixes * fix visualization * remove smoothing and fix order subtasks * add stride rabc computation * add push to hub * add explanation * add kappa expllaination * better rabc logging * feedback pr * remove dataset tolerance * revert dataset tool * revert dataset changes * add credit * run precommit * change path for generate ra_bc * fix type * include sarm in all in pyproject * fix precommit * lazy import matplotlib * lazy import qwen * remove rich console * skip if transformers is not installed? * run only when we have faker * place transformer lazy loading * Dont test if low transformer version * fix * increase transformer * increase as 4.57.0 is yanked * remove pi from all * go back --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
@@ -0,0 +1,276 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
|
||||
class RABCWeights:
|
||||
"""
|
||||
Load precomputed SARM progress values and compute RA-BC weights during training.
|
||||
|
||||
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
|
||||
During training, computes:
|
||||
- progress_delta = progress[t + chunk_size] - progress[t]
|
||||
- rabc_weight based on the delta (paper Eq. 8-9)
|
||||
|
||||
Args:
|
||||
progress_path: Path to parquet file with precomputed progress values
|
||||
chunk_size: Number of frames ahead for computing progress delta
|
||||
head_mode: Which SARM head to use ("sparse" or "dense")
|
||||
kappa: Hard threshold for high-quality samples (default: 0.01)
|
||||
epsilon: Small constant for numerical stability (default: 1e-6)
|
||||
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
|
||||
device: Device to return tensors on
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
progress_path: str | Path,
|
||||
chunk_size: int = 50,
|
||||
head_mode: str = "sparse",
|
||||
kappa: float = 0.01,
|
||||
epsilon: float = 1e-6,
|
||||
fallback_weight: float = 1.0,
|
||||
device: torch.device = None,
|
||||
):
|
||||
self.progress_path = Path(progress_path)
|
||||
self.chunk_size = chunk_size
|
||||
self.head_mode = head_mode
|
||||
self.kappa = kappa
|
||||
self.epsilon = epsilon
|
||||
self.fallback_weight = fallback_weight
|
||||
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Determine progress column name
|
||||
self.progress_column = f"progress_{head_mode}"
|
||||
|
||||
# Load progress values
|
||||
logging.info(f"Loading SARM progress values from {self.progress_path}")
|
||||
self.df = pd.read_parquet(self.progress_path)
|
||||
|
||||
# Check if the requested head mode column exists
|
||||
if self.progress_column not in self.df.columns:
|
||||
available = [c for c in self.df.columns if c.startswith("progress")]
|
||||
raise ValueError(
|
||||
f"Column '{self.progress_column}' not found. Available progress columns: {available}"
|
||||
)
|
||||
|
||||
logging.info(f"Using progress column: {self.progress_column}")
|
||||
|
||||
self.progress_lookup = {}
|
||||
self.episode_lookup = {}
|
||||
|
||||
for _, row in self.df.iterrows():
|
||||
global_idx = int(row["index"])
|
||||
progress = row[self.progress_column]
|
||||
episode_idx = int(row["episode_index"])
|
||||
|
||||
if not np.isnan(progress):
|
||||
self.progress_lookup[global_idx] = float(progress)
|
||||
self.episode_lookup[global_idx] = episode_idx
|
||||
|
||||
# Build episode boundaries for delta computation
|
||||
self.episode_boundaries = {}
|
||||
for episode_idx in self.df["episode_index"].unique():
|
||||
ep_df = self.df[self.df["episode_index"] == episode_idx]
|
||||
self.episode_boundaries[int(episode_idx)] = {
|
||||
"start": int(ep_df["index"].min()),
|
||||
"end": int(ep_df["index"].max()) + 1,
|
||||
}
|
||||
|
||||
logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
|
||||
logging.info(f"Chunk size for delta computation: {chunk_size}")
|
||||
|
||||
# Compute global statistics for weight computation
|
||||
self._compute_global_stats()
|
||||
|
||||
def _compute_global_stats(self):
|
||||
"""Compute global mean and std of progress deltas for weight calculation."""
|
||||
all_deltas = []
|
||||
|
||||
for global_idx, progress in self.progress_lookup.items():
|
||||
episode_idx = self.episode_lookup.get(global_idx)
|
||||
if episode_idx is None:
|
||||
continue
|
||||
|
||||
bounds = self.episode_boundaries.get(episode_idx)
|
||||
if bounds is None:
|
||||
continue
|
||||
|
||||
future_idx = global_idx + self.chunk_size
|
||||
if future_idx >= bounds["end"]:
|
||||
# Near end of episode: use last frame's progress
|
||||
future_idx = bounds["end"] - 1
|
||||
|
||||
future_progress = self.progress_lookup.get(future_idx)
|
||||
if future_progress is not None:
|
||||
delta = future_progress - progress
|
||||
all_deltas.append(delta)
|
||||
|
||||
if all_deltas:
|
||||
self.delta_mean = max(np.mean(all_deltas), 0.0)
|
||||
self.delta_std = max(np.std(all_deltas), self.epsilon)
|
||||
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
|
||||
else:
|
||||
self.delta_mean = 0.0
|
||||
self.delta_std = self.epsilon
|
||||
logging.warning("No valid progress deltas found, using default stats")
|
||||
|
||||
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||
"""
|
||||
Compute RA-BC weights for a batch.
|
||||
|
||||
For each sample:
|
||||
1. Get progress at current frame
|
||||
2. Get progress at frame + chunk_size (within same episode)
|
||||
3. Compute delta = future_progress - current_progress
|
||||
4. Compute weight using paper Eq. 8-9
|
||||
|
||||
Args:
|
||||
batch: Training batch containing "index" key with global frame indices
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Weights tensor (batch_size,) normalized to sum to batch_size
|
||||
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
|
||||
"""
|
||||
indices = batch.get("index")
|
||||
if indices is None:
|
||||
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
|
||||
batch_size = self._get_batch_size(batch)
|
||||
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
|
||||
|
||||
# Convert to list of ints
|
||||
if isinstance(indices, torch.Tensor):
|
||||
indices = indices.cpu().numpy().tolist()
|
||||
elif isinstance(indices, np.ndarray):
|
||||
indices = indices.tolist()
|
||||
|
||||
# Compute deltas and weights for each sample
|
||||
deltas = []
|
||||
for idx in indices:
|
||||
idx = int(idx)
|
||||
delta = self._compute_delta(idx)
|
||||
deltas.append(delta)
|
||||
|
||||
deltas = np.array(deltas, dtype=np.float32)
|
||||
|
||||
# Compute weights from deltas
|
||||
weights = self._compute_weights(deltas)
|
||||
|
||||
# Compute stats before normalization for logging
|
||||
raw_mean_weight = float(np.nanmean(weights))
|
||||
num_zero_weight = int(np.sum(weights == 0))
|
||||
num_full_weight = int(np.sum(weights == 1.0))
|
||||
batch_stats = {
|
||||
"raw_mean_weight": raw_mean_weight,
|
||||
"num_zero_weight": num_zero_weight,
|
||||
"num_full_weight": num_full_weight,
|
||||
}
|
||||
|
||||
weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
|
||||
|
||||
# Normalize to sum to batch_size
|
||||
batch_size = len(weights)
|
||||
weight_sum = weights.sum() + self.epsilon
|
||||
weights = weights * batch_size / weight_sum
|
||||
|
||||
return weights, batch_stats
|
||||
|
||||
def _compute_delta(self, global_idx: int) -> float:
|
||||
"""Compute progress delta for a single frame."""
|
||||
current_progress = self.progress_lookup.get(global_idx)
|
||||
if current_progress is None:
|
||||
return np.nan
|
||||
|
||||
episode_idx = self.episode_lookup.get(global_idx)
|
||||
if episode_idx is None:
|
||||
return np.nan
|
||||
|
||||
bounds = self.episode_boundaries.get(episode_idx)
|
||||
if bounds is None:
|
||||
return np.nan
|
||||
|
||||
future_idx = global_idx + self.chunk_size # Δ = chunk_size
|
||||
if future_idx >= bounds["end"]:
|
||||
# Near end of episode: use last frame's progress instead
|
||||
future_idx = bounds["end"] - 1
|
||||
|
||||
future_progress = self.progress_lookup.get(future_idx)
|
||||
if future_progress is None:
|
||||
return np.nan
|
||||
|
||||
return future_progress - current_progress
|
||||
|
||||
def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Compute RA-BC weights from progress deltas.
|
||||
|
||||
Following paper Eq. 8-9:
|
||||
- Soft weight: ˜wi = clip((ri − (µ − 2σ)) / (4σ + ε), 0, 1)
|
||||
- Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
|
||||
|
||||
Returns:
|
||||
Array of weights
|
||||
"""
|
||||
valid_mask = ~np.isnan(deltas)
|
||||
|
||||
# Compute soft weights using global statistics
|
||||
lower_bound = self.delta_mean - 2 * self.delta_std
|
||||
soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
|
||||
soft_weights = np.clip(soft_weights, 0.0, 1.0)
|
||||
|
||||
# Apply paper's Eq. 9
|
||||
weights = np.zeros_like(deltas, dtype=np.float32)
|
||||
|
||||
# High quality: ri > kappa → weight = 1
|
||||
high_quality_mask = deltas > self.kappa
|
||||
weights[high_quality_mask] = 1.0
|
||||
|
||||
# Moderate quality: 0 <= ri <= kappa → weight = soft_weight
|
||||
moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
|
||||
weights[moderate_mask] = soft_weights[moderate_mask]
|
||||
|
||||
# Negative progress: ri < 0 → weight = 0 (already 0)
|
||||
# Invalid (NaN): use fallback weight
|
||||
weights[~valid_mask] = self.fallback_weight
|
||||
|
||||
return weights
|
||||
|
||||
def _get_batch_size(self, batch: dict) -> int:
|
||||
"""Determine batch size from batch."""
|
||||
for key in ["action", "index"]:
|
||||
if key in batch:
|
||||
val = batch[key]
|
||||
if isinstance(val, (torch.Tensor, np.ndarray)):
|
||||
return val.shape[0]
|
||||
return 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get statistics."""
|
||||
return {
|
||||
"num_frames": len(self.progress_lookup),
|
||||
"chunk_size": self.chunk_size,
|
||||
"head_mode": self.head_mode,
|
||||
"delta_mean": self.delta_mean,
|
||||
"delta_std": self.delta_std,
|
||||
"kappa": self.kappa,
|
||||
}
|
||||
Reference in New Issue
Block a user