mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
Add sarm (#2639)
* add initial modeling * make rewind pretrained policy * add annotation * small fix * add sarm * subtasks * fix spawn * fix rewind discrepancies * Add script to generate embedding for dataset (#2138) * Add generate and validate script * fix precommit * Improve generate embeddings function by using dataset tools (#2206) --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> * cleanup * change order train log * print batch size * update sarm processor * add reward output * change expected features * add image validation * change validation * get state input from dataset stats * raise if no state key is found * pass stats * cleanup and refactor * add episode index to complementary data * add subtask init and detection * revert lerobot_train changes * pass dataset metadata to policy * change loading subtasks * add small logging * fix progress conversion and adding initial frame * use large offset for initial frame (ugly) * Remove rewind, use clip tokenizer * add tests, implement formula 1,2 correctly and cleanup * use task from dataset, cleanup visualizer * simplify * simplify and cleanup code and move compute_temporal_proportions to utils * fix normalization in visualization * Fix visualization and change prompt * fix formatting * add visualize subtask annotations * use qwen thinking * try different prompt * format * update prompt * higher temp, long output * different settings * use instruct * show full resp * split message * Temp: increase tolerance dataset * Fix RA-BC (#2572) * Add next observation loading for RA-BC progress deltas * Compute weights based on temporal progress deltas instead of static rewards * Add hard-masking for negative progress deltas in weight computation * Feat/add dual head (#2582) * Add dual dense sparse head and annotation * Add docs * add dual to processor * cleanup * change sampling in visualize and cleanup * remove validation * remove compile * Feat/test uniform (#2587) * test uniform * add different string for 
misaligned * Fix rewind and add tests * uncomment text implementation * run precommit * Add head mode for ra-bc * fix visualization of single task * add * return per sample loss * Fix RA_BC (#2602) * update rabc implementation * compute rabc beforehand * fix import * add only progress calculation * use precomputed progress * multi gpu processing * import * fix dataset meta data extraction * add logging * logging * log * progress per episode * split differently * move clip to gpu * pre decode frames for an episode * fix cuda initialization * fix import * multi processing * rename * fix import * fix * fix rabc * use last known progress if oob * use last known progress if oob * add misalignment loss with random embeddings * discard previous changes * add selection of models to docs for ra_bc * add transformers dep * extend tolerance * initial commit with new codebase * add tests * fix * remove temporal sampler * drop last frame for sampler * use original ref * some fixes * fix visualization * remove smoothing and fix order subtasks * add stride rabc computation * add push to hub * add explanation * add kappa explanation * better rabc logging * feedback pr * remove dataset tolerance * revert dataset tool * revert dataset changes * add credit * run precommit * change path for generate ra_bc * fix type * include sarm in all in pyproject * fix precommit * lazy import matplotlib * lazy import qwen * remove rich console * skip if transformers is not installed? * run only when we have faker * place transformer lazy loading * Don't test if low transformer version * fix * increase transformer * increase as 4.57.0 is yanked * remove pi from all * go back --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
@@ -62,6 +62,7 @@ def update_policy(
|
||||
accelerator: Accelerator,
|
||||
lr_scheduler=None,
|
||||
lock=None,
|
||||
rabc_weights_provider=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
@@ -78,6 +79,7 @@ def update_policy(
|
||||
accelerator: The Accelerator instance for distributed training and mixed precision.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
@@ -87,9 +89,30 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
# Get RA-BC weights if enabled
|
||||
rabc_batch_weights = None
|
||||
rabc_batch_stats = None
|
||||
if rabc_weights_provider is not None:
|
||||
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# Use per-sample loss when RA-BC is enabled for proper weighting
|
||||
if rabc_batch_weights is not None:
|
||||
# Get per-sample losses
|
||||
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
|
||||
|
||||
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
|
||||
# rabc_batch_weights is already normalized to sum to batch_size
|
||||
epsilon = 1e-6
|
||||
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
|
||||
# Log raw mean weight (before normalization) - this is the meaningful metric
|
||||
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
|
||||
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
|
||||
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
|
||||
else:
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
# Use accelerator's backward method
|
||||
@@ -141,8 +164,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
||||
"""
|
||||
cfg.validate()
|
||||
|
||||
# Create Accelerator if not provided
|
||||
# It will automatically detect if running in distributed mode or single-process mode
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||
@@ -159,6 +180,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# When using accelerate, only the main process should log to avoid duplicate outputs
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
cfg.validate()
|
||||
|
||||
# Only log on main process
|
||||
if is_main_process:
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
@@ -217,6 +240,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
# For SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
@@ -248,6 +275,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
# Load precomputed SARM progress for RA-BC if enabled
|
||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||
rabc_weights = None
|
||||
if cfg.use_rabc:
|
||||
from lerobot.utils.rabc import RABCWeights
|
||||
|
||||
# Get chunk_size from policy config
|
||||
chunk_size = getattr(policy.config, "chunk_size", None)
|
||||
if chunk_size is None:
|
||||
raise ValueError("Chunk size is not found in policy config")
|
||||
|
||||
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
|
||||
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
|
||||
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
|
||||
rabc_weights = RABCWeights(
|
||||
progress_path=cfg.rabc_progress_path,
|
||||
chunk_size=chunk_size,
|
||||
head_mode=head_mode,
|
||||
kappa=getattr(cfg, "rabc_kappa", 0.01),
|
||||
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
|
||||
device=device,
|
||||
)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
@@ -327,7 +377,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
@@ -343,6 +395,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
rabc_weights_provider=rabc_weights,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happen *after* the `step`th training update has completed, so we
|
||||
@@ -359,6 +412,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log RA-BC statistics if enabled
|
||||
if rabc_weights is not None:
|
||||
rabc_stats = rabc_weights.get_stats()
|
||||
wandb_log_dict.update(
|
||||
{
|
||||
"rabc_delta_mean": rabc_stats["delta_mean"],
|
||||
"rabc_delta_std": rabc_stats["delta_std"],
|
||||
"rabc_num_frames": rabc_stats["num_frames"],
|
||||
}
|
||||
)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user