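"""BYOL (Bootstrap Your Own Latent) pre-training script.

Trains the online encoder/predictor of a BYOL model on pairs of randomly
augmented views from RandomAugmentedDataset, applies soft (EMA-style) updates
to the target network, logs metrics to Weights & Biases, and saves the best
online encoder weights with early stopping.
"""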
import argparse

import torch
import wandb
from torch import nn, optim
from torch.nn.functional import cosine_similarity
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Optional

from src.dataset import RandomAugmentedDataset, get_byol_transforms
from src.models import BYOL

def get_data_loaders(
    batch_size: int,
    num_train_samples: int,
    num_val_samples: int,
    shape_params: Optional[dict] = None,
    num_workers: int = 0
):
    augmentations = get_byol_transforms()
    train_dataset = RandomAugmentedDataset(
        augmentations,
        shape_params,
        num_samples=num_train_samples,
        train=True
    )
    val_dataset = RandomAugmentedDataset(
        augmentations,
        shape_params,
        num_samples=num_val_samples,
        train=False
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    return train_loader, val_loader

def build_model(lr: float):
    """Create the BYOL model, optimizer, and LR scheduler on the available device."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BYOL().to(device)
    # Only the online network and predictor receive gradient updates; the target
    # network is updated separately via model.soft_update_target_network().
    optimizer = optim.Adam(
        list(model.online_network.parameters()) + list(model.online_predictor.parameters()),
        lr=lr
    )
    # The scheduler is stepped on validation cosine similarity, hence mode='max'.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=2)
    return model, optimizer, scheduler, device

def train_epoch(
    model: nn.Module,
    optimizer: optim.Optimizer,
    train_loader: DataLoader,
    device: torch.device
) -> dict:
    """Run one training epoch and return averaged loss and diagnostic metrics."""
    model.train()
    running_train_loss = 0.0
    total_cos_sim, total_l2_dist, total_feat_norm, total_grad_norm = 0.0, 0.0, 0.0, 0.0
    num_train_batches = 0

    for view_1, view_2 in tqdm(train_loader, desc="Training"):
        view_1 = view_1.to(device)
        view_2 = view_2.to(device)

        loss = model.loss(view_1, view_2)
        optimizer.zero_grad()
        loss.backward()

        # Diagnostics are computed after backward() so the gradient norm is available.
        with torch.no_grad():
            online_proj1, target_proj1 = model(view_1)
            online_proj2, target_proj2 = model(view_2)
            cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
            l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
            feat_norm = torch.norm(online_proj1, dim=-1).mean().item()
            grad_norm = torch.norm(
                torch.cat([
                    p.grad.flatten()
                    for p in model.online_network.parameters()
                    if p.grad is not None
                ])
            ).item()
            total_cos_sim += cos_sim
            total_l2_dist += l2_dist
            total_feat_norm += feat_norm
            total_grad_norm += grad_norm

        optimizer.step()
        # Soft (EMA-style) update of the target network after each optimizer step.
        model.soft_update_target_network()

        running_train_loss += loss.item()
        num_train_batches += 1

    train_loss = running_train_loss / num_train_batches
    train_cos_sim = total_cos_sim / num_train_batches
    train_l2_dist = total_l2_dist / num_train_batches
    train_feat_norm = total_feat_norm / num_train_batches
    train_grad_norm = total_grad_norm / num_train_batches
    return {
        "loss": train_loss,
        "cos_sim": train_cos_sim,
        "l2_dist": train_l2_dist,
        "feat_norm": train_feat_norm,
        "grad_norm": train_grad_norm,
    }

@torch.no_grad()
def validate(
    model: nn.Module,
    val_loader: DataLoader,
    device: torch.device
) -> dict:
    """Evaluate on the validation loader; no gradients are tracked."""
    model.eval()
    running_val_loss = 0.0
    total_cos_sim, total_l2_dist, total_feat_norm = 0.0, 0.0, 0.0
    num_val_batches = 0

    for view_1, view_2 in tqdm(val_loader, desc="Validation"):
        view_1 = view_1.to(device)
        view_2 = view_2.to(device)

        loss = model.loss(view_1, view_2)
        running_val_loss += loss.item()

        online_proj1, target_proj1 = model(view_1)
        online_proj2, target_proj2 = model(view_2)
        cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
        l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
        feat_norm = torch.norm(online_proj1, dim=-1).mean().item()

        total_cos_sim += cos_sim
        total_l2_dist += l2_dist
        total_feat_norm += feat_norm
        num_val_batches += 1

    val_loss = running_val_loss / num_val_batches
    val_cos_sim = total_cos_sim / num_val_batches
    val_l2_dist = total_l2_dist / num_val_batches
    val_feat_norm = total_feat_norm / num_val_batches
    return {
        "loss": val_loss,
        "cos_sim": val_cos_sim,
        "l2_dist": val_l2_dist,
        "feat_norm": val_feat_norm
    }

def train(
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    device: torch.device,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int,
    early_stopping_patience: int = 3,
    save_path: str = "best_byol.pth"
):
    """Full training loop with W&B logging, checkpointing, and early stopping."""
    best_loss = float("inf")
    epochs_no_improve = 0

    print("Start training...")
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        train_metrics = train_epoch(model, optimizer, train_loader, device)
        val_metrics = validate(model, val_loader, device)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_metrics["loss"],
            "train_cos_sim": train_metrics["cos_sim"],
            "train_l2_dist": train_metrics["l2_dist"],
            "train_feat_norm": train_metrics["feat_norm"],
            "train_grad_norm": train_metrics["grad_norm"],
            "val_loss": val_metrics["loss"],
            "val_cos_sim": val_metrics["cos_sim"],
            "val_l2_dist": val_metrics["l2_dist"],
            "val_feat_norm": val_metrics["feat_norm"],
        })

        print(
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"CosSim: {train_metrics['cos_sim']:.4f} | "
            f"L2Dist: {train_metrics['l2_dist']:.4f}"
        )
        print(
            f"Val Loss: {val_metrics['loss']:.4f} | "
            f"CosSim: {val_metrics['cos_sim']:.4f} | "
            f"L2Dist: {val_metrics['l2_dist']:.4f}"
        )

        current_val_loss = val_metrics["loss"]
        # Checkpoint the online encoder when validation loss improves,
        # or when cosine similarity is already high.
        if current_val_loss < best_loss or val_metrics['cos_sim'] >= 0.86:
            best_loss = current_val_loss
            encoder_state_dict = model.online_network.encoder.state_dict()
            torch.save(encoder_state_dict, save_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        scheduler.step(val_metrics["cos_sim"])

        if epochs_no_improve >= early_stopping_patience:
            print(f"Early stopping on epoch {epoch + 1}")
            break

def main(config: dict):
    wandb.init(project="contrastive_learning_byol", config=config)

    train_loader, val_loader = get_data_loaders(
        batch_size=config["batch_size"],
        num_train_samples=config["num_train_samples"],
        num_val_samples=config["num_val_samples"],
        shape_params=config["shape_params"]
    )
    model, optimizer, scheduler, device = build_model(
        lr=config["lr"]
    )
    train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config["num_epochs"],
        early_stopping_patience=config["early_stopping_patience"],
        save_path=config["save_path"]
    )
    wandb.finish()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train BYOL model")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--num_epochs", type=int, default=15)
    parser.add_argument("--num_train_samples", type=int, default=100000)
    parser.add_argument("--num_val_samples", type=int, default=10000)
    parser.add_argument("--random_intensity", type=int, default=1)
    parser.add_argument("--early_stopping_patience", type=int, default=3)
    parser.add_argument("--save_path", type=str, default="best_byol.pth")
    args = parser.parse_args()

    config = {
        "batch_size": args.batch_size,
        "lr": args.lr,
        "num_epochs": args.num_epochs,
        "num_train_samples": args.num_train_samples,
        "num_val_samples": args.num_val_samples,
        "shape_params": {
            "random_intensity": bool(args.random_intensity)
        },
        "early_stopping_patience": args.early_stopping_patience,
        "save_path": args.save_path
    }
    main(config)
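# Example invocation (the script filename below is illustrative; adjust it to the
# actual file name and make sure the src/ package is importable):
#   python train_byol.py --batch_size 256 --lr 5e-4 --num_epochs 15 --random_intensity 1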