samyhusy committed
Commit 5904988 · verified · 1 Parent(s): bc56659

Upload 6 files

Files changed (6)
  1. app_2.py +130 -0
  2. config.py +42 -0
  3. inference.py +37 -0
  4. requirements.txt +9 -0
  5. resume_training.py +14 -0
  6. train.py +146 -0
app_2.py ADDED
@@ -0,0 +1,130 @@
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import torch
+
+ # Load the model and tokenizer
+ model_name = "jbochi/madlad400-3b-mt"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+
+ def translate_text(text, source_lang, target_lang):
+     """
+     Translate text between English and Persian using MADLAD-400-3B
+     """
+     # Define language codes for the model
+     lang_codes = {
+         "English": "en",
+         "Persian": "fa"
+     }
+
+     source_code = lang_codes[source_lang]
+     target_code = lang_codes[target_lang]
+
+     # Create the translation prompt in the format the model expects
+     prompt = f"<2{target_code}> {text}"
+
+     try:
+         # Tokenize input
+         inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+
+         # Move inputs to the same device as the model
+         inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+         # Generate translation
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_length=512,
+                 num_beams=5,
+                 early_stopping=True,
+                 no_repeat_ngram_size=3,
+                 length_penalty=1.0
+             )
+
+         # Decode the output
+         translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         return translated_text
+
+     except Exception as e:
+         return f"Error during translation: {str(e)}"
+
+ # Create the Gradio interface
+ with gr.Blocks(title="English-Persian Translator") as demo:
+     gr.Markdown(
+         """
+         # 🌍 English-Persian Translator
+         **Powered by MADLAD-400-3B Model**
+
+         Translate text between English and Persian using the state-of-the-art MADLAD-400 model.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             source_lang = gr.Dropdown(
+                 choices=["English", "Persian"],
+                 value="English",
+                 label="Source Language"
+             )
+             input_text = gr.Textbox(
+                 lines=5,
+                 placeholder="Enter text to translate...",
+                 label="Input Text"
+             )
+             translate_btn = gr.Button("Translate", variant="primary")
+
+         with gr.Column():
+             target_lang = gr.Dropdown(
+                 choices=["Persian", "English"],
+                 value="Persian",
+                 label="Target Language"
+             )
+             output_text = gr.Textbox(
+                 lines=5,
+                 label="Translated Text",
+                 interactive=False
+             )
+
+     # Examples
+     gr.Examples(
+         examples=[
+             ["Hello, how are you today?", "English", "Persian"],
+             ["What is your name?", "English", "Persian"],
+             ["سلام، حالتون چطوره؟", "Persian", "English"],
+             ["امروز هوا خوب است", "Persian", "English"]
+         ],
+         inputs=[input_text, source_lang, target_lang],
+         outputs=output_text,
+         fn=translate_text,
+         cache_examples=False
+     )
+
+     # Connect the button
+     translate_btn.click(
+         fn=translate_text,
+         inputs=[input_text, source_lang, target_lang],
+         outputs=output_text
+     )
+
+     # Auto-update the target language based on the source selection
+     def update_target_lang(source_lang):
+         return "Persian" if source_lang == "English" else "English"
+
+     source_lang.change(
+         fn=update_target_lang,
+         inputs=source_lang,
+         outputs=target_lang
+     )
+
+ if __name__ == "__main__":
+     # Launch the app
+     demo.launch(
+         server_name="0.0.0.0",  # Allow external access
+         share=False,  # Set to True to get a public URL
+         debug=True
+     )
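
The f-string prompt above follows the MADLAD-400 convention: the checkpoints are trained with a target-language tag of the form <2xx> prepended to the source text, so the translation direction is controlled entirely by the prompt prefix. A minimal standalone sketch of the same call outside Gradio, reusing the model and tokenizer loaded above (the example strings are illustrative):

    # English -> Persian: "<2fa>" asks the model to emit Persian
    inputs = tokenizer("<2fa> Hello, how are you today?", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=128)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
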
config.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ class Config:
+     # Data
+     dataset_name = "ParsBench/parsinlu-machine-translation-fa-en-alpaca-style"
+     source_lang = "instruction"  # English
+     target_lang = "output"  # Persian
+     max_length = 32
+     batch_size = 24
+
+     # Model
+     input_dim = 5000  # Vocabulary size for English
+     output_dim = 5000  # Vocabulary size for Persian
+     embedding_dim = 64  # Word vector dimensions
+     hidden_dim = 128  # LSTM hidden state size
+     num_layers = 1  # Stacked LSTM layers
+     dropout = 0.1  # Regularization to prevent overfitting
+
+     # Training
+     learning_rate = 0.001
+     num_epochs = 5
+     teacher_forcing_ratio = 0.7  # Mix of ground truth vs model predictions
+
+
+     # Optimization
+     gradient_accumulation_steps = 1
+     use_amp = True  # Mixed precision for speed
+     use_gradient_clipping = True
+     max_grad_norm = 1.0
+
+
+     # Checkpointing
+     checkpoint_interval = 1  # Save a checkpoint every epoch
+     save_best_only = True  # Only save when the model improves
+
+     # Device
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # Paths
+     model_save_path = "models/seq2seq_model.pth"
+     tokenizer_save_path = "models/tokenizers/"
+     checkpoint_path = "models/checkpoint.pth"
+     best_model_path = "models/best_model.pth"
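
Note that teacher_forcing_ratio only takes effect inside the decoder loop of the Seq2Seq model (models/seq2seq.py, which is not part of this upload). A minimal sketch of the pattern it usually implies, with hypothetical names for the loop variables:

    import random
    # At decoding step t: feed the ground-truth token with probability
    # teacher_forcing_ratio, otherwise feed the model's own last prediction.
    teacher_force = random.random() < Config.teacher_forcing_ratio
    next_input = trg[t] if teacher_force else output.argmax(1)
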
inference.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+ from config import Config
+ from models.seq2seq import Encoder, Decoder, Seq2Seq
+ from utils.tokenizer import build_vocab
+ from datasets import load_from_disk
+
+ def translate_sentence(sentence, model, src_tokenizer, src_vocab, trg_vocab, device, max_len=30):
+     model.eval()
+     tokens = src_tokenizer(sentence.lower())
+     src_tensor = torch.tensor([src_vocab["<sos>"]] + [src_vocab[t] for t in tokens] + [src_vocab["<eos>"]]).unsqueeze(1).to(device)
+     with torch.no_grad():
+         hidden = model.encoder(src_tensor)
+     trg_indexes = [trg_vocab["<sos>"]]
+
+     for _ in range(max_len):
+         trg_tensor = torch.tensor([trg_indexes[-1]]).to(device)
+         with torch.no_grad():
+             output, hidden = model.decoder(trg_tensor, hidden)
+         pred_token = output.argmax(1).item()
+         trg_indexes.append(pred_token)
+         if pred_token == trg_vocab["<eos>"]:
+             break
+     return [trg_vocab.get_itos()[i] for i in trg_indexes]
+
+ if __name__ == "__main__":
+     cfg = Config()
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     dataset = load_from_disk("data/raw/")
+     src_tokenizer, src_vocab = build_vocab(dataset, cfg.source_lang)
+     trg_tokenizer, trg_vocab = build_vocab(dataset, cfg.target_lang)
+
+     enc = Encoder(len(src_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
+     dec = Decoder(len(trg_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
+     model = Seq2Seq(enc, dec, device).to(device)
+     model.load_state_dict(torch.load(cfg.model_save_path, map_location=device))
+
+     print(translate_sentence("I love cats", model, src_tokenizer, src_vocab, trg_vocab, device))
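
translate_sentence decodes greedily, one token per step, and returns the raw token list including the <sos> and <eos> markers (looked up through the torchtext vocab's get_itos()). A small usage sketch for turning that into a printable string:

    tokens = translate_sentence("I love cats", model, src_tokenizer, src_vocab, trg_vocab, device)
    print(" ".join(tokens[1:-1]))  # drop <sos>/<eos> (assumes the loop stopped at <eos>)
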
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch>=1.9.0
+ torchtext>=0.10.0
+ datasets>=2.14.0
+ numpy>=1.21.0
+ tqdm>=4.62.0
+ streamlit>=1.22.0
+ transformers>=4.30.0  # used by app_2.py (assumed minimum)
+ gradio>=3.50.0  # used by app_2.py (assumed minimum)
+ accelerate>=0.20.0  # needed for device_map="auto" in app_2.py (assumed minimum)
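
With the file above, the environment installs with:

    pip install -r requirements.txt
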
resume_training.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+ from train import main
+ import os
+
+ if __name__ == "__main__":
+     print("🔄 Resuming training from checkpoint...")
+
+     # Check if a checkpoint exists
+     if not os.path.exists("models/checkpoint.pth"):
+         print("❌ No checkpoint found. Starting fresh training...")
+     else:
+         print("✅ Checkpoint found. Resuming...")
+
+     main()
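
Both entry points end up in train.main(), which calls load_checkpoint itself, so resuming is automatic whenever models/checkpoint.pth exists:

    python train.py            # trains; resumes from models/checkpoint.pth if present
    python resume_training.py  # same, with an explicit pre-flight message first
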
train.py ADDED
@@ -0,0 +1,146 @@
+ import torch
+ import torch.optim as optim
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from datasets import load_from_disk
+ import os
+ from config import Config
+ from utils.tokenizer import build_vocab
+ from utils.preprocessing import collate_fn
+ from models.seq2seq import Encoder, Decoder, Seq2Seq
+ from tqdm import tqdm
+
+ def save_checkpoint(epoch, model, optimizer, scaler, loss, path):
+     """Save training checkpoint"""
+     checkpoint = {
+         'epoch': epoch,
+         'model_state_dict': model.state_dict(),
+         'optimizer_state_dict': optimizer.state_dict(),
+         'scaler_state_dict': scaler.state_dict(),
+         'loss': loss,
+     }
+     torch.save(checkpoint, path)
+     print(f"✅ Checkpoint saved at epoch {epoch}")
+
+ def load_checkpoint(model, optimizer, scaler, path, device):
+     """Load training checkpoint"""
+     if os.path.exists(path):
+         checkpoint = torch.load(path, map_location=device)
+         model.load_state_dict(checkpoint['model_state_dict'])
+         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+         scaler.load_state_dict(checkpoint['scaler_state_dict'])
+         start_epoch = checkpoint['epoch'] + 1
+         best_loss = checkpoint['loss']
+         print(f"✅ Checkpoint loaded. Resuming from epoch {start_epoch}")
+         return start_epoch, best_loss
+     return 0, float('inf')  # Start from the beginning if no checkpoint exists
+
+ def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg):
+     model.train()
+     total_loss = 0
+     optimizer.zero_grad()  # Zero gradients at the start
+
+     loop = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)
+
+     for batch_idx, (src, trg) in enumerate(loop):
+         src, trg = src.to(device), trg.to(device)
+
+         # Mixed precision training
+         with torch.cuda.amp.autocast(enabled=cfg.use_amp):
+             output = model(src, trg)
+             output_dim = output.shape[-1]
+             output = output[1:].reshape(-1, output_dim)
+             trg = trg[1:].reshape(-1)
+             loss = criterion(output, trg) / cfg.gradient_accumulation_steps  # Normalize loss
+
+         scaler.scale(loss).backward()
+
+         # Gradient accumulation
+         if (batch_idx + 1) % cfg.gradient_accumulation_steps == 0:
+             if cfg.use_gradient_clipping:
+                 scaler.unscale_(optimizer)
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
+
+             scaler.step(optimizer)
+             scaler.update()
+             optimizer.zero_grad()
+
+         total_loss += loss.item() * cfg.gradient_accumulation_steps
+         loop.set_postfix(loss=loss.item() * cfg.gradient_accumulation_steps)
+
+     return total_loss / len(dataloader)
+
+
+
+ def main():
+     cfg = Config()
+     device = cfg.device
+     print(f"Using device: {device}")
+
+     # Create directories if they don't exist
+     os.makedirs("models", exist_ok=True)
+     os.makedirs("models/tokenizers", exist_ok=True)
+
+     # Load the full dataset
+     dataset = load_from_disk("data/raw/")
+
+     # Build vocab using the full dataset
+     src_tokenizer, src_vocab = build_vocab(dataset, cfg.source_lang)
+     trg_tokenizer, trg_vocab = build_vocab(dataset, cfg.target_lang)
+
+     # Save tokenizers and vocab for future use
+     torch.save({
+         'src_tokenizer': src_tokenizer,
+         'src_vocab': src_vocab,
+         'trg_tokenizer': trg_tokenizer,
+         'trg_vocab': trg_vocab
+     }, cfg.tokenizer_save_path + "tokenizers.pth")
+
+     # DataLoader over the train split
+     collate = lambda batch: collate_fn(
+         batch, src_tokenizer, trg_tokenizer, src_vocab, trg_vocab, cfg.max_length,
+         src_lang=cfg.source_lang, trg_lang=cfg.target_lang
+     )
+     dataloader = DataLoader(dataset["train"], batch_size=cfg.batch_size, collate_fn=collate, shuffle=True)
+
+     # Model
+     enc = Encoder(len(src_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
+     dec = Decoder(len(trg_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
+     model = Seq2Seq(enc, dec, device).to(device)
+
+     optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate)
+     criterion = nn.CrossEntropyLoss(ignore_index=trg_vocab["<pad>"])  # ignore padding in the target
+     scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)
+
+     # Try to load a checkpoint
+     start_epoch, best_loss = load_checkpoint(model, optimizer, scaler, cfg.checkpoint_path, device)
+
+     for epoch in range(start_epoch, cfg.num_epochs):
+         print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
+
+         try:
+             loss = train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg)
+             print(f"Epoch {epoch+1}/{cfg.num_epochs} | Loss: {loss:.3f}")
+
+             # Save a checkpoint after each epoch
+             save_checkpoint(epoch, model, optimizer, scaler, loss, cfg.checkpoint_path)
+
+             # Save the best model
+             if loss < best_loss:
+                 best_loss = loss
+                 torch.save(model.state_dict(), cfg.best_model_path)
+                 print(f"🎉 New best model saved with loss: {loss:.3f}")
+
+         except RuntimeError as e:
+             if "CUDA out of memory" in str(e):
+                 print("⚠️ GPU out of memory. Saving checkpoint and exiting...")
+                 save_checkpoint(epoch, model, optimizer, scaler, best_loss, cfg.checkpoint_path)  # loss may be unset for this epoch
+                 print("✅ Checkpoint saved. You can resume training later.")
+                 break
+             else:
+                 raise
+
+     print("✅ Training completed!")
+
+ if __name__ == "__main__":
+     main()
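
Because train_one_epoch divides each batch loss by gradient_accumulation_steps before backward(), the gradients that build up over a group of batches average out to one optimizer step over an effective batch of batch_size * gradient_accumulation_steps examples (24 * 1 = 24 with the defaults in config.py). The same pattern with the AMP plumbing stripped out, as a minimal sketch (compute_loss is a hypothetical stand-in for the forward pass and reshaping above):

    k = cfg.gradient_accumulation_steps
    for batch_idx, (src, trg) in enumerate(dataloader):
        loss = compute_loss(model, src, trg, criterion) / k  # scale so accumulated grads average
        loss.backward()                                      # grads add up across the k micro-batches
        if (batch_idx + 1) % k == 0:
            optimizer.step()
            optimizer.zero_grad()
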