# Training script for the cross-attention pair classifier.
# (Hugging Face Space status header removed — it was page chrome, not source.)
import argparse

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing_extensions import Optional

from src.dataset import RandomPairDataset
from src.models import CrossAttentionClassifier, VGGLikeEncode
def visualize_attention(attn_heatmap, epoch: int):
    """Render one attention map as a heatmap and push it to Weights & Biases.

    Args:
        attn_heatmap: 2-D array-like of attention weights (presumably the
            flattened 64x64 map — confirm against the model's output shape).
        epoch: epoch number used in the figure title and image caption.
    """
    figure, axis = plt.subplots(figsize=(6, 6))
    rendered = axis.imshow(attn_heatmap, cmap="hot", interpolation="nearest")
    plt.colorbar(rendered, fraction=0.046, pad=0.04)
    plt.title(f"Attention Heatmap (Flatten 64x64) | Epoch {epoch}")
    wandb.log({"Flatten Attention Heatmap": wandb.Image(figure, caption=f"Flatten 64x64 | Epoch {epoch}")})
    # Close explicitly so repeated calls don't accumulate open figures.
    plt.close(figure)
def get_data_loaders(
    num_train_samples: int,
    num_val_samples: int,
    batch_size: int,
    num_workers: int = 0,
    shape_params: Optional[dict] = None,
):
    """Build train/validation DataLoaders over RandomPairDataset.

    Args:
        num_train_samples: size of the training dataset.
        num_val_samples: size of the validation dataset.
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes (0 = load in main process).
        shape_params: dataset shape configuration forwarded verbatim to
            RandomPairDataset (semantics defined by the dataset class).

    Returns:
        (train_loader, val_loader) tuple.
    """

    def _make_loader(num_samples: int, train: bool) -> DataLoader:
        # Only the training split is shuffled; validation keeps dataset order.
        dataset = RandomPairDataset(
            shape_params=shape_params,
            num_samples=num_samples,
            train=train,
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=train,
            num_workers=num_workers,
        )

    return _make_loader(num_train_samples, True), _make_loader(num_val_samples, False)
def build_model(
    path_to_encoder: str,
    lr: float,
    weight_decay: float,
    step_size: int,
    gamma: float,
    device: torch.device
):
    """Assemble the classifier and its training objects.

    Loads pretrained encoder weights, wraps the encoder in a
    CrossAttentionClassifier, and builds loss/optimizer/scheduler.

    Args:
        path_to_encoder: path to the encoder state_dict checkpoint.
        lr: Adam learning rate.
        weight_decay: Adam L2 penalty.
        step_size: epochs between StepLR decays.
        gamma: multiplicative LR decay factor.
        device: device to place the model on.

    Returns:
        (model, criterion, optimizer, scheduler) tuple.
    """
    encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=False)
    # Bug fix: without map_location, a checkpoint saved on CUDA fails to load
    # on a CPU-only machine (and vice versa lands on the wrong device).
    encoder.load_state_dict(torch.load(path_to_encoder, map_location=device))
    model = CrossAttentionClassifier(encoder=encoder)
    model = model.to(device)
    # BCEWithLogitsLoss: model outputs raw logits, sigmoid is fused into the loss.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    return model, criterion, optimizer, scheduler
def train_epoch(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    train_loader: DataLoader,
    device: torch.device
):
    """Run one optimization pass over the training set.

    Returns:
        (epoch_loss, epoch_acc): dataset-size-weighted mean loss and the
        fraction of correct binary predictions (sigmoid threshold 0.5).
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    for img1, img2, labels in tqdm(train_loader, desc="Training", leave=False):
        img1 = img1.to(device)
        img2 = img2.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Attention weights are only needed for visualization elsewhere.
        logits, _ = model(img1, img2)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the final mean is exact even when the
        # last batch is smaller.
        loss_sum += loss.item() * img1.size(0)
        predictions = (torch.sigmoid(logits) > 0.5).float()
        num_correct += (predictions == labels).sum().item()
        num_seen += labels.size(0)
    return loss_sum / len(train_loader.dataset), num_correct / num_seen
def validate(
    model: nn.Module,
    criterion: nn.Module,
    val_loader: DataLoader,
    device: torch.device
):
    """Evaluate the model on the validation set.

    Returns:
        (epoch_loss, epoch_acc): dataset-size-weighted mean loss and the
        fraction of correct binary predictions (sigmoid threshold 0.5).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    # Fix: model.eval() does not disable autograd — the original built the
    # full computation graph for every validation batch, wasting memory and
    # time. no_grad() leaves all returned numbers unchanged.
    with torch.no_grad():
        for img1, img2, labels in tqdm(val_loader, desc="Validation", leave=False):
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
            logits, attn_weights = model(img1, img2)
            loss = criterion(logits, labels)
            # Weight by batch size so a smaller final batch is averaged correctly.
            running_loss += loss.item() * img1.size(0)
            preds = (torch.sigmoid(logits) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_acc = correct / total
    return epoch_loss, epoch_acc
def train(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 30,
    save_path: str = "best_attention_classifier.pth",
    patience: Optional[int] = None,
):
    """Full training loop with wandb logging and best-checkpoint saving.

    Each epoch: train, validate, step the LR scheduler, log metrics, save the
    model whenever validation loss improves, and log one attention heatmap
    from a validation batch.

    Args:
        model: classifier returning (logits, attn_weights).
        criterion: loss on logits vs. labels.
        optimizer: optimizer stepped by train_epoch.
        scheduler: LR scheduler stepped once per epoch.
        train_loader / val_loader: data loaders.
        device: device the batches are moved to.
        num_epochs: maximum number of epochs.
        save_path: where the best (lowest val loss) state_dict is written.
        patience: if set, stop after this many epochs without val-loss
            improvement. Default None keeps the original behavior (the
            original counted stale epochs but never acted on the count).
    """
    best_val_loss = float("inf")
    epochs_no_improve = 0
    print("Start training...")
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train_loss, train_acc = train_epoch(model, criterion, optimizer, train_loader, device)
        val_loss, val_acc = validate(model, criterion, val_loader, device)
        scheduler.step()
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "lr": optimizer.param_groups[0]["lr"],
        })
        print(
            f"learning rate: {optimizer.param_groups[0]['lr']:.6f}, "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        )
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if patience is not None and epochs_no_improve >= patience:
                print(f"Early stopping after {epoch + 1} epochs")
                break
        # Log one attention heatmap for qualitative monitoring. Model is
        # still in eval mode here (set inside validate).
        with torch.no_grad():
            sample_img1, sample_img2, sample_labels = next(iter(val_loader))
            sample_img1, sample_img2 = sample_img1.to(device), sample_img2.to(device)
            _, sample_attn_weights = model(sample_img1, sample_img2)
            wandb.log({
                "attention_std": sample_attn_weights.std().item(),
                "attention_mean": sample_attn_weights.mean().item(),
            })
            attn_heatmap = sample_attn_weights[0].detach().cpu().numpy()
            # Bug fix: use the same 1-based epoch number as every other
            # log/print (the original labeled the heatmap with 0-based epoch).
            visualize_attention(attn_heatmap, epoch + 1)
def main(config):
    """Wire config into data loaders, model, and the training loop.

    Args:
        config: dict with all run hyperparameters (see the __main__ block).
    """
    wandb.init(project="cross_attention_classifier", config=config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader = get_data_loaders(
        shape_params=config["shape_params"],
        num_train_samples=config["num_train_samples"],
        num_val_samples=config["num_val_samples"],
        batch_size=config["batch_size"],
    )
    model, criterion, optimizer, scheduler = build_model(
        path_to_encoder=config["path_to_encoder"],
        lr=config["lr"],
        weight_decay=config["weight_decay"],
        step_size=config["step_size"],
        gamma=config["gamma"],
        device=device,
    )
    train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=config["num_epochs"],
        save_path=config["save_path"],
    )
    wandb.finish()
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Train classifier model") | |
| parser.add_argument("--path_to_encoder", type=str, default="best_byol.pth") | |
| parser.add_argument("--batch_size", type=int, default=256) | |
| parser.add_argument("--lr", type=float, default=8e-5) | |
| parser.add_argument("--weight_decay", type=float, default=1e-4) | |
| parser.add_argument("--step_size", type=int, default=10) | |
| parser.add_argument("--gamma", type=float, default=0.1) | |
| parser.add_argument("--num_epochs", type=int, default=10) | |
| parser.add_argument("--num_train_samples", type=int, default=10000) | |
| parser.add_argument("--num_val_samples", type=int, default=2000) | |
| parser.add_argument("--save_path", type=str, default="best_attention_classifier.pth") | |
| args = parser.parse_args() | |
| config = { | |
| "path_to_encoder": args.path_to_encoder, | |
| "batch_size": args.batch_size, | |
| "lr": args.lr, | |
| "weight_decay": args.weight_decay, | |
| "step_size": args.step_size, | |
| "gamma": args.gamma, | |
| "num_epochs": args.num_epochs, | |
| "num_train_samples": args.num_train_samples, | |
| "num_val_samples": args.num_val_samples, | |
| "save_path": args.save_path, | |
| } | |
| if "shape_params" not in config: | |
| config["shape_params"] = {} | |
| main(config) | |