Upload 6 files
- app_2.py +130 -0
- config.py +42 -0
- inference.py +37 -0
- requirements.txt +6 -0
- resume_training.py +14 -0
- train.py +146 -0
app_2.py
ADDED
@@ -0,0 +1,130 @@
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Load the model and tokenizer
model_name = "jbochi/madlad400-3b-mt"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

def translate_text(text, source_lang, target_lang):
    """
    Translate text between English and Persian using MADLAD-400-3B.
    """
    # Language codes for the model
    lang_codes = {
        "English": "en",
        "Persian": "fa"
    }

    source_code = lang_codes[source_lang]
    target_code = lang_codes[target_lang]

    # Create the translation prompt in the format the model expects
    prompt = f"<2{target_code}> {text}"

    try:
        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        # Move inputs to the same device as the model
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate translation
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=512,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=3,
                length_penalty=1.0
            )

        # Decode the output
        translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return translated_text

    except Exception as e:
        return f"Error during translation: {str(e)}"

# Create the Gradio interface
with gr.Blocks(title="English-Persian Translator") as demo:
    gr.Markdown(
        """
        # 🌍 English-Persian Translator
        **Powered by the MADLAD-400-3B model**

        Translate text between English and Persian using the MADLAD-400 model.
        """
    )

    with gr.Row():
        with gr.Column():
            source_lang = gr.Dropdown(
                choices=["English", "Persian"],
                value="English",
                label="Source Language"
            )
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to translate...",
                label="Input Text"
            )
            translate_btn = gr.Button("Translate", variant="primary")

        with gr.Column():
            target_lang = gr.Dropdown(
                choices=["Persian", "English"],
                value="Persian",
                label="Target Language"
            )
            output_text = gr.Textbox(
                lines=5,
                label="Translated Text",
                interactive=False
            )

    # Examples
    gr.Examples(
        examples=[
            ["Hello, how are you today?", "English", "Persian"],
            ["What is your name?", "English", "Persian"],
            ["سلام، حالتون چطوره؟", "Persian", "English"],
            ["امروز هوا خوب است", "Persian", "English"]
        ],
        inputs=[input_text, source_lang, target_lang],
        outputs=output_text,
        fn=translate_text,
        cache_examples=False
    )

    # Connect the button
    translate_btn.click(
        fn=translate_text,
        inputs=[input_text, source_lang, target_lang],
        outputs=output_text
    )

    # Auto-update the target language based on the source selection
    def update_target_lang(source_lang):
        return "Persian" if source_lang == "English" else "English"

    source_lang.change(
        fn=update_target_lang,
        inputs=source_lang,
        outputs=target_lang
    )

if __name__ == "__main__":
    # Launch the app
    demo.launch(
        server_name="0.0.0.0",  # Allow external access
        share=False,            # Set to True to get a public URL
        debug=True
    )
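The `<2{target_code}>` prefix is the whole translation interface of MADLAD-400; a minimal sketch of the same call without the Gradio UI (same model name as above, generation settings illustrative):

# Minimal sketch: one English-to-Persian translation with MADLAD-400-3B.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("jbochi/madlad400-3b-mt")
model = AutoModelForSeq2SeqLM.from_pretrained("jbochi/madlad400-3b-mt")

inputs = tokenizer("<2fa> Hello, how are you today?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
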
config.py
ADDED
@@ -0,0 +1,42 @@
import torch

class Config:
    # Data
    dataset_name = "ParsBench/parsinlu-machine-translation-fa-en-alpaca-style"
    source_lang = "instruction"  # English column
    target_lang = "output"       # Persian column
    max_length = 32
    batch_size = 24

    # Model
    input_dim = 5000      # Vocabulary size for English
    output_dim = 5000     # Vocabulary size for Persian
    embedding_dim = 64    # Word vector dimensions
    hidden_dim = 128      # LSTM hidden state size
    num_layers = 1        # Stacked LSTM layers
    dropout = 0.1         # Regularization to prevent overfitting

    # Training
    learning_rate = 0.001
    num_epochs = 5
    teacher_forcing_ratio = 0.7  # Mix of ground truth vs. model predictions

    # Optimization
    gradient_accumulation_steps = 1
    use_amp = True  # Mixed precision for speed
    use_gradient_clipping = True
    max_grad_norm = 1.0

    # Checkpointing
    checkpoint_interval = 1  # Save every epoch
    save_best_only = True    # Only save when the model improves

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Paths
    model_save_path = "models/seq2seq_model.pth"
    tokenizer_save_path = "models/tokenizers/"
    checkpoint_path = "models/checkpoint.pth"
    best_model_path = "models/best_model.pth"
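Note that train.py and inference.py size the embedding layers from the built vocabularies (len(src_vocab), len(trg_vocab)), so input_dim and output_dim above are never actually read. A minimal consumer sketch, using only the attribute names defined above:

# Sketch: how the training scripts read Config (plain class attributes).
from config import Config

cfg = Config()
print(f"device={cfg.device}, epochs={cfg.num_epochs}, "
      f"batch_size={cfg.batch_size}, lr={cfg.learning_rate}")
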
inference.py
ADDED
@@ -0,0 +1,37 @@
import torch
from config import Config
from models.seq2seq import Encoder, Decoder, Seq2Seq
from utils.tokenizer import build_vocab
from datasets import load_from_disk

def translate_sentence(sentence, model, src_tokenizer, src_vocab, trg_vocab, device, max_len=30):
    model.eval()
    tokens = src_tokenizer(sentence.lower())
    src_tensor = torch.tensor(
        [src_vocab["<sos>"]] + [src_vocab[t] for t in tokens] + [src_vocab["<eos>"]]
    ).unsqueeze(1).to(device)
    with torch.no_grad():
        hidden = model.encoder(src_tensor)
    trg_indexes = [trg_vocab["<sos>"]]

    # Greedy decoding: feed the last predicted token back in until <eos> or max_len
    for _ in range(max_len):
        trg_tensor = torch.tensor([trg_indexes[-1]]).to(device)
        with torch.no_grad():
            output, hidden = model.decoder(trg_tensor, hidden)
        pred_token = output.argmax(1).item()
        trg_indexes.append(pred_token)
        if pred_token == trg_vocab["<eos>"]:
            break
    return [trg_vocab.get_itos()[i] for i in trg_indexes]

if __name__ == "__main__":
    cfg = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = load_from_disk("data/raw/")
    src_tokenizer, src_vocab = build_vocab(dataset, cfg.source_lang)
    trg_tokenizer, trg_vocab = build_vocab(dataset, cfg.target_lang)

    # Config defines embedding_dim/hidden_dim/num_layers (not emb_dim/hid_dim/n_layers)
    enc = Encoder(len(src_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    dec = Decoder(len(trg_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    model = Seq2Seq(enc, dec, device).to(device)
    model.load_state_dict(torch.load(cfg.model_save_path, map_location=device))

    print(translate_sentence("I love cats", model, src_tokenizer, src_vocab, trg_vocab, device))
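models/seq2seq.py is not part of this upload. The following is a minimal sketch that matches the constructor signatures and the encoder/decoder calls used in train.py and inference.py; everything about its internals (unidirectional LSTM, single linear output head, sequence-first tensors) is an assumption, not the author's actual implementation:

# Hypothetical models/seq2seq.py consistent with how train.py and inference.py use it.
import random
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers)

    def forward(self, src):  # src: (src_len, batch)
        _, hidden = self.rnn(self.embedding(src))
        return hidden        # (h, c) pair handed to the decoder

class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, token, hidden):  # token: (batch,)
        emb = self.embedding(token.unsqueeze(0))   # (1, batch, emb)
        output, hidden = self.rnn(emb, hidden)
        return self.fc(output.squeeze(0)), hidden  # logits: (batch, vocab)

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder, self.decoder, self.device = encoder, decoder, device

    def forward(self, src, trg, teacher_forcing_ratio=0.7):
        trg_len, batch = trg.shape
        vocab = self.decoder.fc.out_features
        outputs = torch.zeros(trg_len, batch, vocab, device=self.device)
        hidden = self.encoder(src)
        token = trg[0]  # <sos>
        for t in range(1, trg_len):
            logits, hidden = self.decoder(token, hidden)
            outputs[t] = logits
            # Teacher forcing: sometimes feed ground truth, sometimes the prediction
            token = trg[t] if random.random() < teacher_forcing_ratio else logits.argmax(1)
        return outputs

This shape convention (time-major, position 0 left as zeros) is what makes train.py's output[1:] / trg[1:] slicing line up.
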
requirements.txt
ADDED
@@ -0,0 +1,6 @@
torch>=1.9.0
torchtext>=0.10.0
datasets>=2.14.0
numpy>=1.21.0
tqdm>=4.62.0
streamlit>=1.22.0
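Note that app_2.py additionally imports gradio and transformers (and its device_map="auto" requires accelerate), none of which are listed here. If the Gradio app is meant to run from this same environment, lines like the following would need to be appended (version pins are illustrative, not from the source):

gradio>=4.0.0
transformers>=4.30.0
accelerate>=0.20.0
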
resume_training.py
ADDED
@@ -0,0 +1,14 @@
import torch
from train import main
import os

if __name__ == "__main__":
    print("🔄 Resuming training from checkpoint...")

    # Check if a checkpoint exists
    if not os.path.exists("models/checkpoint.pth"):
        print("❌ No checkpoint found. Starting fresh training...")
    else:
        print("✅ Checkpoint found. Resuming...")

    main()
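The prints here are informational only: load_checkpoint in train.py returns (0, inf) when no checkpoint file exists, so main() handles both the fresh-start and resume cases on its own. Typical usage:

python train.py             # fresh run, or auto-resume if models/checkpoint.pth exists
python resume_training.py   # same behavior, with an explicit status message first
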
train.py
ADDED
@@ -0,0 +1,146 @@
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk
import os
from config import Config
from utils.tokenizer import build_vocab
from utils.preprocessing import collate_fn
from models.seq2seq import Encoder, Decoder, Seq2Seq
from tqdm import tqdm

def save_checkpoint(epoch, model, optimizer, scaler, loss, path):
    """Save a training checkpoint."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, path)
    print(f"✅ Checkpoint saved at epoch {epoch}")

def load_checkpoint(model, optimizer, scaler, path, device):
    """Load a training checkpoint."""
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['loss']
        print(f"✅ Checkpoint loaded. Resuming from epoch {start_epoch}")
        return start_epoch, best_loss
    return 0, float('inf')  # Start from the beginning if there is no checkpoint

def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg):
    model.train()
    total_loss = 0
    optimizer.zero_grad()  # Zero gradients at the start

    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)

    for batch_idx, (src, trg) in enumerate(loop):
        src, trg = src.to(device), trg.to(device)

        # Mixed-precision forward pass
        with torch.cuda.amp.autocast(enabled=cfg.use_amp):
            output = model(src, trg)
            output_dim = output.shape[-1]
            # Drop the <sos> position before computing the loss
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)
            loss = criterion(output, trg) / cfg.gradient_accumulation_steps  # Normalize loss

        scaler.scale(loss).backward()

        # Gradient accumulation
        if (batch_idx + 1) % cfg.gradient_accumulation_steps == 0:
            if cfg.use_gradient_clipping:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * cfg.gradient_accumulation_steps
        loop.set_postfix(loss=loss.item() * cfg.gradient_accumulation_steps)

    return total_loss / len(dataloader)


def main():
    cfg = Config()
    device = cfg.device
    print(f"Using device: {device}")

    # Create directories if they don't exist
    os.makedirs("models", exist_ok=True)
    os.makedirs("models/tokenizers", exist_ok=True)

    # Load the full dataset
    dataset = load_from_disk("data/raw/")

    # Build vocabularies from the full dataset
    src_tokenizer, src_vocab = build_vocab(dataset, cfg.source_lang)
    trg_tokenizer, trg_vocab = build_vocab(dataset, cfg.target_lang)

    # Save tokenizers and vocabularies for future use
    torch.save({
        'src_tokenizer': src_tokenizer,
        'src_vocab': src_vocab,
        'trg_tokenizer': trg_tokenizer,
        'trg_vocab': trg_vocab
    }, cfg.tokenizer_save_path + "tokenizers.pth")

    # DataLoader over the train split
    collate = lambda batch: collate_fn(
        batch, src_tokenizer, trg_tokenizer, src_vocab, trg_vocab, cfg.max_length,
        src_lang=cfg.source_lang, trg_lang=cfg.target_lang
    )
    dataloader = DataLoader(dataset["train"], batch_size=cfg.batch_size, collate_fn=collate, shuffle=True)

    # Model
    enc = Encoder(len(src_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    dec = Decoder(len(trg_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    model = Seq2Seq(enc, dec, device).to(device)

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate)
    # Ignore padding in the loss; the loss is over target tokens, so use the target vocab's <pad>
    criterion = nn.CrossEntropyLoss(ignore_index=trg_vocab["<pad>"])
    scaler = torch.cuda.amp.GradScaler()

    # Try to load a checkpoint
    start_epoch, best_loss = load_checkpoint(model, optimizer, scaler, cfg.checkpoint_path, device)

    loss = best_loss  # Keep `loss` defined for the OOM handler below
    for epoch in range(start_epoch, cfg.num_epochs):
        print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")

        try:
            loss = train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg)
            print(f"Epoch {epoch+1}/{cfg.num_epochs} | Loss: {loss:.3f}")

            # Save a checkpoint after each epoch
            save_checkpoint(epoch, model, optimizer, scaler, loss, cfg.checkpoint_path)

            # Save the best model
            if loss < best_loss:
                best_loss = loss
                torch.save(model.state_dict(), cfg.best_model_path)
                print(f"🎉 New best model saved with loss: {loss:.3f}")

        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                print("⚠️ GPU out of memory. Saving checkpoint and exiting...")
                save_checkpoint(epoch, model, optimizer, scaler, loss, cfg.checkpoint_path)
                print("✅ Checkpoint saved. You can resume training later.")
                break
            else:
                raise

    print("✅ Training completed!")

if __name__ == "__main__":
    main()
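train.py persists the tokenizers and vocabularies to models/tokenizers/tokenizers.pth, but inference.py rebuilds them from data/raw/ instead. A sketch of loading the saved bundle directly (keys as written by main() above; weights_only=False is needed on newer PyTorch because the bundle contains arbitrary Python objects, not just tensors):

# Hypothetical loader for the bundle saved by train.py's main().
import torch

bundle = torch.load("models/tokenizers/tokenizers.pth", map_location="cpu",
                    weights_only=False)  # pickled tokenizer/vocab objects
src_tokenizer = bundle['src_tokenizer']
src_vocab = bundle['src_vocab']
trg_tokenizer = bundle['trg_tokenizer']
trg_vocab = bundle['trg_vocab']

Reusing the bundle avoids re-tokenizing the dataset at inference time and guarantees the vocab indices match the ones the model was trained with.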