[Port HIL_SERL] Final fixes for the Reward Classifier (#598)

This commit is contained in:
Eugene Mironov
2025-01-06 17:34:00 +07:00
committed by GitHub
parent 35de91ef2b
commit c5bca1cf0f
11 changed files with 59 additions and 19 deletions
+13 -4
View File
@@ -45,7 +45,7 @@ from lerobot.common.utils.utils import (
)
def get_model(cfg, logger):
def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
@@ -64,6 +64,12 @@ def create_balanced_sampler(dataset, cfg):
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
# Check if the device supports AMP
# Here is an example of the issue that says that MPS doesn't support AMP properply
return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
# Single epoch training loop with AMP support and progress tracking
model.train()
@@ -77,7 +83,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
outputs = model(images)
loss = criterion(outputs.logits, labels)
@@ -119,7 +125,10 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
samples = []
running_loss = 0
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = batch[cfg.training.image_key].to(device)
labels = batch[cfg.training.label_key].float().to(device)
@@ -170,7 +179,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
return accuracy, eval_info
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg))