Extend reward classifier for multiple camera views (#626)

This commit is contained in:
Michel Aractingi
2025-01-13 13:57:49 +01:00
committed by GitHub
parent c5bca1cf0f
commit 3bb5ed5e91
9 changed files with 192 additions and 49 deletions
+4 -3
View File
@@ -22,6 +22,7 @@ from pprint import pformat
import hydra
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
@@ -79,7 +79,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
pbar = tqdm(train_loader, desc="Training")
for batch_idx, batch in enumerate(pbar):
start_time = time.perf_counter()
images = batch[cfg.training.image_key].to(device)
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
@@ -130,7 +130,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = batch[cfg.training.image_key].to(device)
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
labels = batch[cfg.training.label_key].float().to(device)
outputs = model(images)
@@ -163,6 +163,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
accuracy = 100 * correct / total
avg_loss = running_loss / len(val_loader)
print(f"Average validation loss {avg_loss}, and accuracy {accuracy}")
eval_info = {
"loss": avg_loss,