mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
[Port HIL_SERL] Final fixes for the Reward Classifier (#598)
This commit is contained in:
@@ -45,7 +45,7 @@ from lerobot.common.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def get_model(cfg, logger):
|
||||
def get_model(cfg, logger): # noqa I001
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
model = Classifier(classifier_config)
|
||||
if cfg.resume:
|
||||
@@ -64,6 +64,12 @@ def create_balanced_sampler(dataset, cfg):
|
||||
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||
|
||||
|
||||
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
|
||||
# Check if the device supports AMP
|
||||
# Here is an example of the issue that says that MPS doesn't support AMP properply
|
||||
return cfg.training.use_amp and device.type in ("cuda", "cpu")
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
||||
# Single epoch training loop with AMP support and progress tracking
|
||||
model.train()
|
||||
@@ -77,7 +83,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
# Forward pass with optional AMP
|
||||
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
@@ -119,7 +125,10 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
|
||||
samples = []
|
||||
running_loss = 0
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
|
||||
):
|
||||
for batch in tqdm(val_loader, desc="Validation"):
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
@@ -170,7 +179,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
|
||||
return accuracy, eval_info
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
|
||||
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
|
||||
def train(cfg: DictConfig) -> None:
|
||||
# Main training pipeline with support for resuming training
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
Reference in New Issue
Block a user