mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
Extend reward classifier for multiple camera views (#626)
This commit is contained in:
committed by
AdilZouitine
parent
b68730474a
commit
6139df553d
@@ -22,6 +22,7 @@ from pprint import pformat
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
@@ -30,7 +31,6 @@ from torch.cuda.amp import GradScaler
|
||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
@@ -79,7 +79,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
|
||||
pbar = tqdm(train_loader, desc="Training")
|
||||
for batch_idx, batch in enumerate(pbar):
|
||||
start_time = time.perf_counter()
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
# Forward pass with optional AMP
|
||||
@@ -130,7 +130,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
|
||||
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
|
||||
):
|
||||
for batch in tqdm(val_loader, desc="Validation"):
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
outputs = model(images)
|
||||
@@ -163,6 +163,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
avg_loss = running_loss / len(val_loader)
|
||||
print(f"Average validation loss {avg_loss}, and accuracy {accuracy}")
|
||||
|
||||
eval_info = {
|
||||
"loss": avg_loss,
|
||||
|
||||
Reference in New Issue
Block a user