Added possiblity to record and replay delta actions during teleoperation rather than absolute actions

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-12 19:25:41 +01:00
committed by AdilZouitine
parent 4057904238
commit 9c14830cd9
11 changed files with 63 additions and 618 deletions
+14 -10
View File
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,6 +21,7 @@ import hydra
import numpy as np
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@@ -32,7 +31,6 @@ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from tqdm import tqdm
import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
@@ -45,6 +43,7 @@ from lerobot.common.utils.utils import (
init_hydra_config,
set_global_seed,
)
from lerobot.scripts.server.buffer import random_shift
def get_model(cfg, logger): # noqa I001
@@ -82,6 +81,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
for batch_idx, batch in enumerate(pbar):
start_time = time.perf_counter()
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
images = [random_shift(img, 4) for img in images]
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
@@ -161,14 +161,17 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(min( cfg.eval.num_samples_to_log - len(samples), len(images))):
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
samples.append(
{
**{f"image_{img_key}": wandb.Image(images[img_idx][i].cpu()) for img_idx, img_key in enumerate(cfg.training.image_keys)},
**{
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
},
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
"confidence": confidence,
@@ -270,11 +273,13 @@ def train(cfg: DictConfig) -> None:
device = get_safe_torch_device(cfg.device, log=True)
set_global_seed(cfg.seed)
out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "classifier"
out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
# Setup dataset and dataloaders
dataset = LeRobotDataset(cfg.dataset_repo_id)
dataset = LeRobotDataset(
cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
)
logging.info(f"Dataset size: {len(dataset)}")
n_total = len(dataset)
@@ -282,14 +287,13 @@ def train(cfg: DictConfig) -> None:
train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
sampler = create_balanced_sampler(train_dataset, cfg)
train_loader = DataLoader(
train_dataset,
batch_size=cfg.training.batch_size,
num_workers=cfg.training.num_workers,
sampler=sampler,
pin_memory=True,
pin_memory=device.type == "cuda",
)
val_loader = DataLoader(
@@ -297,7 +301,7 @@ def train(cfg: DictConfig) -> None:
batch_size=cfg.eval.batch_size,
shuffle=False,
num_workers=cfg.training.num_workers,
pin_memory=True,
pin_memory=device.type == "cuda",
)
# Resume training if requested