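"""Training script for the seq2seq translation model.

Trains with mixed precision (AMP), gradient accumulation, and optional
gradient clipping. A checkpoint is written after every epoch, and an existing
checkpoint at cfg.checkpoint_path is picked up automatically on the next run,
so interrupted training resumes where it left off.
"""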
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk
import os
from config import Config
from utils.tokenizer import build_vocab
from utils.preprocessing import collate_fn
from models.seq2seq import Encoder, Decoder, Seq2Seq
from tqdm import tqdm

def save_checkpoint(epoch, model, optimizer, scaler, loss, path):
    """Save training checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, path)
    print(f"βœ… Checkpoint saved at epoch {epoch}")

def load_checkpoint(model, optimizer, scaler, path, device):
    """Load training checkpoint"""
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['loss']
        print(f"βœ… Checkpoint loaded. Resuming from epoch {start_epoch}")
        return start_epoch, best_loss
    return 0, float('inf')  # Start from beginning if no checkpoint
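
# Note: PyTorch 2.6 changed torch.load's default to weights_only=True. The
# checkpoint above holds only tensors and plain Python values, so it still
# loads; pass weights_only=False explicitly if custom objects are ever added.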

def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg):
    model.train()
    total_loss = 0
    optimizer.zero_grad()  # Zero gradients at start
    
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)

    for batch_idx, (src, trg) in enumerate(loop):
        src, trg = src.to(device), trg.to(device)

        # Mixed precision training
        with torch.cuda.amp.autocast(enabled=cfg.use_amp):
            output = model(src, trg)
            output_dim = output.shape[-1]
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)
            loss = criterion(output, trg) / cfg.gradient_accumulation_steps  # Normalize loss

        scaler.scale(loss).backward()

        # Gradient accumulation: the optimizer steps only every N batches,
        # giving an effective batch size of batch_size * gradient_accumulation_steps.
        if (batch_idx + 1) % cfg.gradient_accumulation_steps == 0:
            if cfg.use_gradient_clipping:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * cfg.gradient_accumulation_steps
        loop.set_postfix(loss=loss.item() * cfg.gradient_accumulation_steps)

    # Flush any leftover accumulated gradients when the batch count is not a
    # multiple of gradient_accumulation_steps, so the final partial step is
    # not silently discarded.
    if len(dataloader) % cfg.gradient_accumulation_steps != 0:
        if cfg.use_gradient_clipping:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    return total_loss / len(dataloader)
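

# --- Optional: validation pass ----------------------------------------------
# A minimal sketch of an evaluation loop to pair with train_one_epoch. It
# assumes a held-out DataLoader built with the same collate_fn, and that the
# Seq2Seq forward can be called as model(src, trg) exactly as in training.
# It is not wired into main() below.
@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Average the loss over a held-out split without updating weights."""
    model.eval()
    total_loss = 0
    for src, trg in dataloader:
        src, trg = src.to(device), trg.to(device)
        output = model(src, trg)
        output_dim = output.shape[-1]
        # Drop the first timestep (the <sos> placeholder), as in training.
        loss = criterion(output[1:].reshape(-1, output_dim), trg[1:].reshape(-1))
        total_loss += loss.item()
    return total_loss / len(dataloader)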



def main():
    cfg = Config()
    device = cfg.device
    print(f"Using device: {device}")

    # Create directories if they don't exist
    os.makedirs("models", exist_ok=True)
    os.makedirs("models/tokenizers", exist_ok=True)

    # Load dataset (full dataset)
    dataset = load_from_disk("data/raw/")

    # Build vocab using full dataset
    src_tokenizer, src_vocab = build_vocab(dataset, cfg.source_lang)
    trg_tokenizer, trg_vocab = build_vocab(dataset, cfg.target_lang)

    # Save tokenizers and vocab for future use
    torch.save({
        'src_tokenizer': src_tokenizer,
        'src_vocab': src_vocab,
        'trg_tokenizer': trg_tokenizer,
        'trg_vocab': trg_vocab
    }, os.path.join(cfg.tokenizer_save_path, "tokenizers.pth"))  # join guards against a missing trailing slash

    # DataLoader with train split
    collate = lambda batch: collate_fn(
        batch, src_tokenizer, trg_tokenizer, src_vocab, trg_vocab, cfg.max_length,
        src_lang=cfg.source_lang, trg_lang=cfg.target_lang
    )
    dataloader = DataLoader(dataset["train"], batch_size=cfg.batch_size, collate_fn=collate, shuffle=True)
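    # num_workers is left at its default of 0 on purpose: the lambda collate
    # fn cannot be pickled, so num_workers > 0 would need a top-level function
    # (e.g. via functools.partial) on platforms that spawn worker processes.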

    # Model
    enc = Encoder(len(src_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    dec = Decoder(len(trg_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    model = Seq2Seq(enc, dec, device).to(device)

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate)
    # The loss runs over target tokens, so the padding index must come from the target vocab.
    criterion = nn.CrossEntropyLoss(ignore_index=trg_vocab["<pad>"])
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)  # Keep the scaler consistent with autocast

    # Try to load checkpoint
    start_epoch, best_loss = load_checkpoint(model, optimizer, scaler, cfg.checkpoint_path, device)

    for epoch in range(start_epoch, cfg.num_epochs):
        print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
        
        try:
            loss = train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg)
            print(f"Epoch {epoch+1}/{cfg.num_epochs} | Loss: {loss:.3f}")

            # Save checkpoint after each epoch
            save_checkpoint(epoch, model, optimizer, scaler, loss, cfg.checkpoint_path)

            # Save best model
            if loss < best_loss:
                best_loss = loss
                torch.save(model.state_dict(), cfg.best_model_path)
                print(f"πŸŽ‰ New best model saved with loss: {loss:.3f}")

        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                print("⚠️ GPU out of memory. Saving checkpoint and exiting...")
                # `loss` is undefined when the OOM happens mid-epoch, so record
                # the best loss seen so far instead.
                save_checkpoint(epoch, model, optimizer, scaler, best_loss, cfg.checkpoint_path)
                print("βœ… Checkpoint saved. You can resume training later.")
                break
            else:
                raise  # Re-raise with the original traceback

    print("βœ… Training completed!")

if __name__ == "__main__":
    main()