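"""Train a seq2seq translation model with mixed precision (AMP), gradient
accumulation, and resumable checkpointing.

Expects a dataset saved at data/raw/ via datasets.save_to_disk and
hyperparameters defined in config.Config.
"""
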
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk
import os
from config import Config
from utils.tokenizer import build_vocab
from utils.preprocessing import collate_fn
from models.seq2seq import Encoder, Decoder, Seq2Seq
from tqdm import tqdm


def save_checkpoint(epoch, model, optimizer, scaler, loss, path):
    """Save training checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, path)
    print(f"✅ Checkpoint saved at epoch {epoch}")


def load_checkpoint(model, optimizer, scaler, path, device):
    """Load training checkpoint"""
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['loss']
        print(f"✅ Checkpoint loaded. Resuming from epoch {start_epoch}")
        return start_epoch, best_loss
    return 0, float('inf')  # Start from scratch if no checkpoint exists
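

# With gradient accumulation the optimizer steps once every
# cfg.gradient_accumulation_steps batches, so the effective batch size is
# batch_size * gradient_accumulation_steps (e.g. 32 * 4 behaves like 128).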
def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg):
    model.train()
    total_loss = 0
    optimizer.zero_grad()  # Start the epoch with clean gradients
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)
    for batch_idx, (src, trg) in enumerate(loop):
        src, trg = src.to(device), trg.to(device)
        # Mixed precision: autocast runs eligible ops in float16
        with torch.cuda.amp.autocast(enabled=cfg.use_amp):
            output = model(src, trg)
            output_dim = output.shape[-1]
            # Skip the first (<sos>) position and flatten for the loss
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)
            # Normalize so the accumulated gradient matches one large batch
            loss = criterion(output, trg) / cfg.gradient_accumulation_steps
        # GradScaler scales the loss so float16 gradients don't underflow
        scaler.scale(loss).backward()
        # Gradient accumulation: step the optimizer only every N batches
        if (batch_idx + 1) % cfg.gradient_accumulation_steps == 0:
            if cfg.use_gradient_clipping:
                # Unscale before clipping so the norm is measured correctly
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        total_loss += loss.item() * cfg.gradient_accumulation_steps
        loop.set_postfix(loss=loss.item() * cfg.gradient_accumulation_steps)
    # Gradients from a trailing partial accumulation group are not applied;
    # they are cleared by the zero_grad() at the start of the next epoch.
    return total_loss / len(dataloader)
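

# Fields expected on config.Config (defined in config.py): device, use_amp,
# gradient_accumulation_steps, use_gradient_clipping, max_grad_norm,
# batch_size, max_length, source_lang, target_lang, embedding_dim,
# hidden_dim, num_layers, learning_rate, num_epochs, checkpoint_path,
# best_model_path, tokenizer_save_path.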
def main():
    cfg = Config()
    device = cfg.device
    print(f"Using device: {device}")
    # Create output directories if they don't exist
    os.makedirs("models", exist_ok=True)
    os.makedirs("models/tokenizers", exist_ok=True)
    # Load the full dataset saved with datasets.save_to_disk
    dataset = load_from_disk("data/raw/")
    # Build source and target vocabularies over the full dataset
    src_tokenizer, src_vocab = build_vocab(dataset, cfg.source_lang)
    trg_tokenizer, trg_vocab = build_vocab(dataset, cfg.target_lang)
    # Save tokenizers and vocabularies for inference and later runs
    torch.save({
        'src_tokenizer': src_tokenizer,
        'src_vocab': src_vocab,
        'trg_tokenizer': trg_tokenizer,
        'trg_vocab': trg_vocab
    }, cfg.tokenizer_save_path + "tokenizers.pth")
    # DataLoader over the train split
    collate = lambda batch: collate_fn(
        batch, src_tokenizer, trg_tokenizer, src_vocab, trg_vocab, cfg.max_length,
        src_lang=cfg.source_lang, trg_lang=cfg.target_lang
    )
    dataloader = DataLoader(dataset["train"], batch_size=cfg.batch_size, collate_fn=collate, shuffle=True)
    # Model
    enc = Encoder(len(src_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    dec = Decoder(len(trg_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    model = Seq2Seq(enc, dec, device).to(device)
    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate)
    # Ignore padding positions in the *target* sequence when computing the loss
    criterion = nn.CrossEntropyLoss(ignore_index=trg_vocab["<pad>"])
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)
    # Resume from a checkpoint if one exists
    start_epoch, best_loss = load_checkpoint(model, optimizer, scaler, cfg.checkpoint_path, device)
    for epoch in range(start_epoch, cfg.num_epochs):
        print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
        try:
            loss = train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg)
            print(f"Epoch {epoch+1}/{cfg.num_epochs} | Loss: {loss:.3f}")
            # Save a resumable checkpoint after every epoch
            save_checkpoint(epoch, model, optimizer, scaler, loss, cfg.checkpoint_path)
            # Track and save the best model weights separately
            if loss < best_loss:
                best_loss = loss
                torch.save(model.state_dict(), cfg.best_model_path)
                print(f"🎉 New best model saved with loss: {loss:.3f}")
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                print("⚠️ GPU out of memory. Saving checkpoint and exiting...")
                # This epoch's loss is unavailable; record the best loss so far
                save_checkpoint(epoch, model, optimizer, scaler, best_loss, cfg.checkpoint_path)
                print("✅ Checkpoint saved. You can resume training later.")
                break
            else:
                raise
    print("✅ Training completed!")


if __name__ == "__main__":
    main()
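

# Usage (assuming this file is the training entry point, e.g. train.py):
#   python train.py
# Training resumes automatically from cfg.checkpoint_path when it exists.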