[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
committed by AdilZouitine
parent 76df8a31b3
commit 38f5fa4523
79 changed files with 2782 additions and 788 deletions
+80 -22
View File
@@ -14,7 +14,6 @@
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat
import hydra
@@ -28,14 +27,16 @@ from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler
from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
format_big_number,
@@ -50,7 +51,11 @@ def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
model.load_state_dict(
Classifier.from_pretrained(
str(logger.last_pretrained_model_dir)
).state_dict()
)
return model
@@ -62,7 +67,9 @@ def create_balanced_sampler(dataset, cfg):
class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels]
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
return WeightedRandomSampler(
weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
@@ -71,7 +78,9 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
def train_epoch(
model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
):
# Single epoch training loop with AMP support and progress tracking
model.train()
correct = 0
@@ -85,7 +94,11 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
with (
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext()
):
outputs = model(images)
loss = criterion(outputs.logits, labels)
@@ -130,7 +143,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
@@ -143,7 +158,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
):
outputs = model(images)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
)
else:
outputs = model(images)
@@ -161,16 +178,24 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
for i in range(
min(cfg.eval.num_samples_to_log - len(samples), len(images))
):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
confidence = [
round(prob, 3) for prob in outputs.probabilities[i].tolist()
]
samples.append(
{
**{
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
f"image_{img_key}": wandb.Image(
images[img_idx][i].cpu()
)
for img_idx, img_key in enumerate(
cfg.training.image_keys
)
},
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
@@ -238,15 +263,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
elif device.type == "mps":
torch.mps.synchronize()
with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
with (
profiler.profile(record_shapes=True) as prof,
profiler.record_function("model_inference"),
):
_ = model(x)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
)
inference_times = np.array(inference_times)
avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
avg, median, std = (
inference_times.mean(),
np.median(inference_times),
inference_times.std(),
)
print(
f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
)
@@ -264,7 +298,11 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
return avg, median, std
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
@hydra.main(
version_base="1.2",
config_path="../configs/policy",
config_name="hilserl_classifier",
)
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg))
@@ -278,7 +316,9 @@ def train(cfg: DictConfig) -> None:
# Setup dataset and dataloaders
dataset = LeRobotDataset(
cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
cfg.dataset_repo_id,
root=cfg.dataset_root,
local_files_only=cfg.local_files_only,
)
logging.info(f"Dataset size: {len(dataset)}")
@@ -314,7 +354,9 @@ def train(cfg: DictConfig) -> None:
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
@@ -327,7 +369,9 @@ def train(cfg: DictConfig) -> None:
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
diff = DeepDiff(
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
)
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
@@ -346,7 +390,11 @@ def train(cfg: DictConfig) -> None:
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
criterion = (
nn.BCEWithLogitsLoss()
if model.config.num_classes == 2
else nn.CrossEntropyLoss()
)
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
# Log model parameters
@@ -362,7 +410,17 @@ def train(cfg: DictConfig) -> None:
for epoch in range(cfg.training.num_epochs):
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
train_epoch(
model,
train_loader,
criterion,
optimizer,
grad_scaler,
device,
logger,
step,
cfg,
)
# Periodic validation
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0: