* add initial modeling

* make rewind pretrained policy

* add annotation

* small fix

* add sarm

* subtasks

* fix spawn

* fix rewind discrepancies

* Add script to generate embedding for dataset (#2138)

* Add generate and validate script

* fix precommit

* Improve generate embeddings function by using dataset tools (#2206)

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>

* cleanup

* change order train log

* print batch size

* update sarm processor

* add reward output

* change expected features

* add image validation

* change validation

* get state input from dataset stats

* raise if no state key is found

* pass stats

* cleanup and refactor

* add episode inddex to complementary data

* add subtask init and detection

* revert lerobot_train changes

* pass dataset metadata to policy

* change loadig subtasks

* add small logging

* fix progress conversion and adding initial frame

* use large offset for initial frame (ugly)

* Remove rewind, use clip tokenizer

* add tests, implement formula 1,2 correctly and cleanup

* use task from dataset, cleanup visualizer

* simplify

* simplify and cleanup code and move compute_temporal_proportions to utils

* fix normalization in visualization

* Fix visualization and change prompt

* fix formatting

* add visualize subtask annotations

* use qwen thinking

* try different prompt

* format

* update prompt

* higher temp, long output

* different settings

* use instruct

* show full resp

* split message

* Temp: increase tolerance dataset

* Fix RA-BC (#2572)

* Add next observation loading for RA-BC progress deltas

* Compute weights based on temporal progress deltas instead of static rewards

* Add hard-masking for negative progress deltas in weight computation

* Feat/add dual head (#2582)

* Add dual dense sparse head and annotation

* Add docs

* add dual to procesor

* cleanup

* change sampling in visualize and cleanup

* remove validation

* remove compile

* Feat/test uniform (#2587)

* test uniform

* add different string for misaligned

* Fix rewind and add tests

* uncomment text implementation

* run precommit

* Add head mode for ra-bc

* fix visalization of single task

* add

* return per sample loss

* Fix RA_BC (#2602)

* update rabc implementation

* compute rabc beforehand

* fix import

* add only progress calulation

* use precomputed progress

* multi gpu processing

* import

* fix dataset meta data extraction

* add logging

* logging

* log

* progress per episode

* split differently

* move clip to gpu

* pre decode frames for an episode

* fix cuda initalization

* fix import

* multi processing

* rename

* fix import

* fix

* fix rabc

* use last known progress if oob

* use last known progress if oob

* add misalignment loss with random embeddings

* discard previous changes

* add selection of models to docs for ra_bc

* add transformers dep

* extend tolerance

* initial commit with new codebase

* add tests

* fix

* remove temporal sampler

* drop last frame for sampler

* use original ref

* some fixes

* fix visualization

* remove smoothing and fix order subtasks

* add stride rabc computation

* add push to hub

* add explanation

* add kappa expllaination

* better rabc logging

* feedback pr

* remove dataset tolerance

* revert dataset tool

* revert dataset changes

* add credit

* run precommit

* change path for generate ra_bc

* fix type

* include sarm in all in pyproject

* fix precommit

* lazy import matplotlib

* lazy import qwen

* remove rich console

* skip if transformers is not installed?

* run only when we have faker

* place transformer lazy loading

* Dont test if low transformer version

* fix

* increase transformer

* increase as 4.57.0 is yanked

* remove pi from all

* go back

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
This commit is contained in:
Pepijn
2025-12-18 12:50:32 +01:00
committed by GitHub
parent 4a151a9682
commit f04958527e
30 changed files with 6449 additions and 29 deletions
+67 -4
View File
@@ -62,6 +62,7 @@ def update_policy(
accelerator: Accelerator,
lr_scheduler=None,
lock=None,
rabc_weights_provider=None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -78,6 +79,7 @@ def update_policy(
accelerator: The Accelerator instance for distributed training and mixed precision.
lr_scheduler: An optional learning rate scheduler.
lock: An optional lock for thread-safe optimizer updates.
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
Returns:
A tuple containing:
@@ -87,9 +89,30 @@ def update_policy(
start_time = time.perf_counter()
policy.train()
# Get RA-BC weights if enabled
rabc_batch_weights = None
rabc_batch_stats = None
if rabc_weights_provider is not None:
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
# Use per-sample loss when RA-BC is enabled for proper weighting
if rabc_batch_weights is not None:
# Get per-sample losses
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
# rabc_batch_weights is already normalized to sum to batch_size
epsilon = 1e-6
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
# Log raw mean weight (before normalization) - this is the meaningful metric
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
else:
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
@@ -141,8 +164,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg: A `TrainPipelineConfig` object containing all training configurations.
accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
cfg.validate()
# Create Accelerator if not provided
# It will automatically detect if running in distributed mode or single-process mode
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
@@ -159,6 +180,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# When using accelerate, only the main process should log to avoid duplicate outputs
is_main_process = accelerator.is_main_process
cfg.validate()
# Only log on main process
if is_main_process:
logging.info(pformat(cfg.to_dict()))
@@ -217,6 +240,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For SARM, always provide dataset_meta for progress normalization
if cfg.policy.type == "sarm":
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
@@ -248,6 +275,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
# Load precomputed SARM progress for RA-BC if enabled
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
rabc_weights = None
if cfg.use_rabc:
from lerobot.utils.rabc import RABCWeights
# Get chunk_size from policy config
chunk_size = getattr(policy.config, "chunk_size", None)
if chunk_size is None:
raise ValueError("Chunk size is not found in policy config")
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
rabc_weights = RABCWeights(
progress_path=cfg.rabc_progress_path,
chunk_size=chunk_size,
head_mode=head_mode,
kappa=getattr(cfg, "rabc_kappa", 0.01),
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
device=device,
)
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
@@ -327,7 +377,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
)
if is_main_process:
logging.info("Start offline training on a fixed dataset")
logging.info(
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
)
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
@@ -343,6 +395,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -359,6 +412,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weights is not None:
rabc_stats = rabc_weights.get_stats()
wandb_log_dict.update(
{
"rabc_delta_mean": rabc_stats["delta_mean"],
"rabc_delta_std": rabc_stats["delta_std"],
"rabc_num_frames": rabc_stats["num_frames"],
}
)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()