[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:41:27 +00:00
committed by AdilZouitine
parent 2945bbb221
commit 7c05755823
123 changed files with 1161 additions and 3425 deletions
+16 -54
View File
@@ -52,19 +52,13 @@ def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
model.load_state_dict(
Classifier.from_pretrained(
str(logger.last_pretrained_model_dir)
).state_dict()
)
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
return model
def create_balanced_sampler(dataset, cfg):
# Get underlying dataset if using Subset
original_dataset = (
dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
)
original_dataset = dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
# Get indices if using Subset (for slicing)
indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
@@ -83,9 +77,7 @@ def create_balanced_sampler(dataset, cfg):
class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels]
return WeightedRandomSampler(
weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
@@ -94,9 +86,7 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch(
model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
):
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
# Single epoch training loop with AMP support and progress tracking
model.train()
correct = 0
@@ -110,11 +100,7 @@ def train_epoch(
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with (
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext()
):
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
outputs = model(images)
loss = criterion(outputs.logits, labels)
@@ -159,9 +145,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):
with (
torch.no_grad(),
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
@@ -174,9 +158,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):
):
outputs = model(images)
inference_times.append(
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)
else:
outputs = model(images)
@@ -194,24 +176,16 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(
min(cfg.eval.num_samples_to_log - len(samples), len(images))
):
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [
round(prob, 3) for prob in outputs.probabilities[i].tolist()
]
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
samples.append(
{
**{
f"image_{img_key}": wandb.Image(
images[img_idx][i].cpu()
)
for img_idx, img_key in enumerate(
cfg.training.image_keys
)
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
},
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
@@ -286,9 +260,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
_ = model(x)
inference_times.append(
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)
inference_times = np.array(inference_times)
@@ -314,9 +286,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
return avg, median, std
def train(
cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
) -> None:
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None) -> None:
if out_dir is None:
raise NotImplementedError()
if job_name is None:
@@ -372,9 +342,7 @@ def train(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
@@ -387,9 +355,7 @@ def train(
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
@@ -408,11 +374,7 @@ def train(
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = (
nn.BCEWithLogitsLoss()
if model.config.num_classes == 2
else nn.CrossEntropyLoss()
)
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
# Log model parameters