* add initial modeling

* make rewind pretrained policy

* add annotation

* small fix

* add sarm

* subtasks

* fix spawn

* fix rewind discrepancies

* Add script to generate embedding for dataset (#2138)

* Add generate and validate script

* fix precommit

* Improve generate embeddings function by using dataset tools (#2206)

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>

* cleanup

* change order train log

* print batch size

* update sarm processor

* add reward output

* change expected features

* add image validation

* change validation

* get state input from dataset stats

* raise if no state key is found

* pass stats

* cleanup and refactor

* add episode inddex to complementary data

* add subtask init and detection

* revert lerobot_train changes

* pass dataset metadata to policy

* change loadig subtasks

* add small logging

* fix progress conversion and adding initial frame

* use large offset for initial frame (ugly)

* Remove rewind, use clip tokenizer

* add tests, implement formula 1,2 correctly and cleanup

* use task from dataset, cleanup visualizer

* simplify

* simplify and cleanup code and move compute_temporal_proportions to utils

* fix normalization in visualization

* Fix visualization and change prompt

* fix formatting

* add visualize subtask annotations

* use qwen thinking

* try different prompt

* format

* update prompt

* higher temp, long output

* different settings

* use instruct

* show full resp

* split message

* Temp: increase tolerance dataset

* Fix RA-BC (#2572)

* Add next observation loading for RA-BC progress deltas

* Compute weights based on temporal progress deltas instead of static rewards

* Add hard-masking for negative progress deltas in weight computation

* Feat/add dual head (#2582)

* Add dual dense sparse head and annotation

* Add docs

* add dual to procesor

* cleanup

* change sampling in visualize and cleanup

* remove validation

* remove compile

* Feat/test uniform (#2587)

* test uniform

* add different string for misaligned

* Fix rewind and add tests

* uncomment text implementation

* run precommit

* Add head mode for ra-bc

* fix visalization of single task

* add

* return per sample loss

* Fix RA_BC (#2602)

* update rabc implementation

* compute rabc beforehand

* fix import

* add only progress calulation

* use precomputed progress

* multi gpu processing

* import

* fix dataset meta data extraction

* add logging

* logging

* log

* progress per episode

* split differently

* move clip to gpu

* pre decode frames for an episode

* fix cuda initalization

* fix import

* multi processing

* rename

* fix import

* fix

* fix rabc

* use last known progress if oob

* use last known progress if oob

* add misalignment loss with random embeddings

* discard previous changes

* add selection of models to docs for ra_bc

* add transformers dep

* extend tolerance

* initial commit with new codebase

* add tests

* fix

* remove temporal sampler

* drop last frame for sampler

* use original ref

* some fixes

* fix visualization

* remove smoothing and fix order subtasks

* add stride rabc computation

* add push to hub

* add explanation

* add kappa expllaination

* better rabc logging

* feedback pr

* remove dataset tolerance

* revert dataset tool

* revert dataset changes

* add credit

* run precommit

* change path for generate ra_bc

* fix type

* include sarm in all in pyproject

* fix precommit

* lazy import matplotlib

* lazy import qwen

* remove rich console

* skip if transformers is not installed?

* run only when we have faker

* place transformer lazy loading

* Dont test if low transformer version

* fix

* increase transformer

* increase as 4.57.0 is yanked

* remove pi from all

* go back

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
Pepijn
2025-12-18 12:50:32 +01:00
committed by GitHub
parent 4a151a9682
commit f04958527e
30 changed files with 6449 additions and 29 deletions
+276
View File
@@ -0,0 +1,276 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import torch
class RABCWeights:
"""
Load precomputed SARM progress values and compute RA-BC weights during training.
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
During training, computes:
- progress_delta = progress[t + chunk_size] - progress[t]
- rabc_weight based on the delta (paper Eq. 8-9)
Args:
progress_path: Path to parquet file with precomputed progress values
chunk_size: Number of frames ahead for computing progress delta
head_mode: Which SARM head to use ("sparse" or "dense")
kappa: Hard threshold for high-quality samples (default: 0.01)
epsilon: Small constant for numerical stability (default: 1e-6)
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
device: Device to return tensors on
"""
def __init__(
self,
progress_path: str | Path,
chunk_size: int = 50,
head_mode: str = "sparse",
kappa: float = 0.01,
epsilon: float = 1e-6,
fallback_weight: float = 1.0,
device: torch.device = None,
):
self.progress_path = Path(progress_path)
self.chunk_size = chunk_size
self.head_mode = head_mode
self.kappa = kappa
self.epsilon = epsilon
self.fallback_weight = fallback_weight
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Determine progress column name
self.progress_column = f"progress_{head_mode}"
# Load progress values
logging.info(f"Loading SARM progress values from {self.progress_path}")
self.df = pd.read_parquet(self.progress_path)
# Check if the requested head mode column exists
if self.progress_column not in self.df.columns:
available = [c for c in self.df.columns if c.startswith("progress")]
raise ValueError(
f"Column '{self.progress_column}' not found. Available progress columns: {available}"
)
logging.info(f"Using progress column: {self.progress_column}")
self.progress_lookup = {}
self.episode_lookup = {}
for _, row in self.df.iterrows():
global_idx = int(row["index"])
progress = row[self.progress_column]
episode_idx = int(row["episode_index"])
if not np.isnan(progress):
self.progress_lookup[global_idx] = float(progress)
self.episode_lookup[global_idx] = episode_idx
# Build episode boundaries for delta computation
self.episode_boundaries = {}
for episode_idx in self.df["episode_index"].unique():
ep_df = self.df[self.df["episode_index"] == episode_idx]
self.episode_boundaries[int(episode_idx)] = {
"start": int(ep_df["index"].min()),
"end": int(ep_df["index"].max()) + 1,
}
logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
logging.info(f"Chunk size for delta computation: {chunk_size}")
# Compute global statistics for weight computation
self._compute_global_stats()
def _compute_global_stats(self):
"""Compute global mean and std of progress deltas for weight calculation."""
all_deltas = []
for global_idx, progress in self.progress_lookup.items():
episode_idx = self.episode_lookup.get(global_idx)
if episode_idx is None:
continue
bounds = self.episode_boundaries.get(episode_idx)
if bounds is None:
continue
future_idx = global_idx + self.chunk_size
if future_idx >= bounds["end"]:
# Near end of episode: use last frame's progress
future_idx = bounds["end"] - 1
future_progress = self.progress_lookup.get(future_idx)
if future_progress is not None:
delta = future_progress - progress
all_deltas.append(delta)
if all_deltas:
self.delta_mean = max(np.mean(all_deltas), 0.0)
self.delta_std = max(np.std(all_deltas), self.epsilon)
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
else:
self.delta_mean = 0.0
self.delta_std = self.epsilon
logging.warning("No valid progress deltas found, using default stats")
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
"""
Compute RA-BC weights for a batch.
For each sample:
1. Get progress at current frame
2. Get progress at frame + chunk_size (within same episode)
3. Compute delta = future_progress - current_progress
4. Compute weight using paper Eq. 8-9
Args:
batch: Training batch containing "index" key with global frame indices
Returns:
Tuple of:
- Weights tensor (batch_size,) normalized to sum to batch_size
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
"""
indices = batch.get("index")
if indices is None:
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
batch_size = self._get_batch_size(batch)
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
# Convert to list of ints
if isinstance(indices, torch.Tensor):
indices = indices.cpu().numpy().tolist()
elif isinstance(indices, np.ndarray):
indices = indices.tolist()
# Compute deltas and weights for each sample
deltas = []
for idx in indices:
idx = int(idx)
delta = self._compute_delta(idx)
deltas.append(delta)
deltas = np.array(deltas, dtype=np.float32)
# Compute weights from deltas
weights = self._compute_weights(deltas)
# Compute stats before normalization for logging
raw_mean_weight = float(np.nanmean(weights))
num_zero_weight = int(np.sum(weights == 0))
num_full_weight = int(np.sum(weights == 1.0))
batch_stats = {
"raw_mean_weight": raw_mean_weight,
"num_zero_weight": num_zero_weight,
"num_full_weight": num_full_weight,
}
weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
# Normalize to sum to batch_size
batch_size = len(weights)
weight_sum = weights.sum() + self.epsilon
weights = weights * batch_size / weight_sum
return weights, batch_stats
def _compute_delta(self, global_idx: int) -> float:
"""Compute progress delta for a single frame."""
current_progress = self.progress_lookup.get(global_idx)
if current_progress is None:
return np.nan
episode_idx = self.episode_lookup.get(global_idx)
if episode_idx is None:
return np.nan
bounds = self.episode_boundaries.get(episode_idx)
if bounds is None:
return np.nan
future_idx = global_idx + self.chunk_size # Δ = chunk_size
if future_idx >= bounds["end"]:
# Near end of episode: use last frame's progress instead
future_idx = bounds["end"] - 1
future_progress = self.progress_lookup.get(future_idx)
if future_progress is None:
return np.nan
return future_progress - current_progress
def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
"""
Compute RA-BC weights from progress deltas.
Following paper Eq. 8-9:
- Soft weight: ˜wi = clip((ri 2σ)) / (4σ + ε), 0, 1)
- Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
Returns:
Array of weights
"""
valid_mask = ~np.isnan(deltas)
# Compute soft weights using global statistics
lower_bound = self.delta_mean - 2 * self.delta_std
soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
soft_weights = np.clip(soft_weights, 0.0, 1.0)
# Apply paper's Eq. 9
weights = np.zeros_like(deltas, dtype=np.float32)
# High quality: ri > kappa → weight = 1
high_quality_mask = deltas > self.kappa
weights[high_quality_mask] = 1.0
# Moderate quality: 0 <= ri <= kappa → weight = soft_weight
moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
weights[moderate_mask] = soft_weights[moderate_mask]
# Negative progress: ri < 0 → weight = 0 (already 0)
# Invalid (NaN): use fallback weight
weights[~valid_mask] = self.fallback_weight
return weights
def _get_batch_size(self, batch: dict) -> int:
"""Determine batch size from batch."""
for key in ["action", "index"]:
if key in batch:
val = batch[key]
if isinstance(val, (torch.Tensor, np.ndarray)):
return val.shape[0]
return 1
def get_stats(self) -> dict:
"""Get statistics."""
return {
"num_frames": len(self.progress_lookup),
"chunk_size": self.chunk_size,
"head_mode": self.head_mode,
"delta_mean": self.delta_mean,
"delta_std": self.delta_std,
"kappa": self.kappa,
}