"""Training script for the GNN-encoder seq2seq transliteration model (Teluguish -> Telugu)."""

import os
import json
import csv
import time
import datetime

import pandas as pd
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from .config import (
    DATA_PATH,
    VOCAB_PATH,
    MODEL_PATH,
    DEVICE,
    BATCH_SIZE,
    NUM_EPOCHS,
    LR,
    TEACHER_FORCING,
)
from .preprocess import build_word_pairs
from .dataset import WordTranslitDataset, collate_fn
from .model import GNNEncoder, Decoder, Seq2Seq


def load_vocab():
    with open(VOCAB_PATH, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    return vocab


def setup_logging():
    """Create logs folder and CSV + TXT log files."""
    os.makedirs("logs", exist_ok=True)
    txt_log_path = "logs/train_log.txt"
    csv_log_path = "logs/train_metrics.csv"

    # Create CSV header if it does not exist yet
    if not os.path.exists(csv_log_path):
        with open(csv_log_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["epoch", "token_loss", "batch_size", "learning_rate", "timestamp"])

    return txt_log_path, csv_log_path


def append_txt_log(path, message):
    with open(path, "a", encoding="utf-8") as f:
        f.write(message + "\n")


def append_csv_log(path, epoch, loss, lr):
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(
            [
                epoch,
                loss,
                BATCH_SIZE,
                lr,
                datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            ]
        )


def main():
    # make sure the saved/ dir exists for the model checkpoint
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

    # logging setup
    txt_log, csv_log = setup_logging()
    append_txt_log(txt_log, "\n========================")
    append_txt_log(txt_log, f"TRAINING STARTED AT {datetime.datetime.now()}")
    append_txt_log(txt_log, "========================\n")

    # load vocab + dictionary
    vocab = load_vocab()
    src_char2idx = vocab["src_char2idx"]
    tgt_char2idx = vocab["tgt_char2idx"]
    # Special-token keys assumed to be "<pad>", "<sos>", "<eos>";
    # adjust if the vocab file uses different names.
    TGT_PAD = tgt_char2idx["<pad>"]
    TGT_SOS = tgt_char2idx["<sos>"]
    TGT_EOS = tgt_char2idx["<eos>"]

    # load the parquet data
    df = pd.read_parquet(DATA_PATH)

    # build word pairs (teluguish, telugu) using the preprocessing helper
    src_words, tgt_words = build_word_pairs(df)

    dataset = WordTranslitDataset(src_words, tgt_words, src_char2idx, tgt_char2idx)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
    )

    # create model
    encoder = GNNEncoder(src_vocab_size=len(src_char2idx)).to(DEVICE)
    decoder = Decoder(tgt_vocab_size=len(tgt_char2idx)).to(DEVICE)
    model = Seq2Seq(encoder, decoder, TGT_SOS, TGT_EOS, TGT_PAD).to(DEVICE)

    criterion = nn.CrossEntropyLoss(ignore_index=TGT_PAD)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    best_loss = float("inf")

    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        epoch_loss = 0.0
        total_tokens = 0
        start_time = time.time()

        progress_bar = tqdm(
            enumerate(dataloader, start=1),
            total=len(dataloader),
            desc=f"Epoch {epoch}/{NUM_EPOCHS}",
            leave=True,
            ncols=100,
            dynamic_ncols=True,
        )

        for batch_idx, (x, edge_idx, batch_vec, tgt_padded, tgt_lengths) in progress_bar:
            x = x.to(DEVICE)
            edge_idx = edge_idx.to(DEVICE)
            batch_vec = batch_vec.to(DEVICE)
            tgt_padded = tgt_padded.to(DEVICE)

            optimizer.zero_grad()

            outputs = model(
                x,
                edge_idx,
                batch_vec,
                tgt_padded,
                teacher_forcing_ratio=TEACHER_FORCING,
            )

            B, T, V = outputs.size()

            # ignore t=0 (<sos>) for the loss
            logits_flat = outputs[:, 1:, :].reshape(-1, V)
            tgt_flat = tgt_padded[:, 1:].reshape(-1)

            loss = criterion(logits_flat, tgt_flat)
            loss.backward()
            optimizer.step()

            tokens = tgt_flat.ne(TGT_PAD).sum().item()
            epoch_loss += loss.item() * tokens
            total_tokens += tokens

            # update tqdm with the batch loss only every few steps to avoid flicker
            if batch_idx % 50 == 0 or batch_idx == len(dataloader):
                progress_bar.set_postfix({"batch_loss": f"{loss.item():.4f}"})
        # epoch-level metrics (token-averaged loss)
        avg_loss = epoch_loss / max(total_tokens, 1)
        lr = optimizer.param_groups[0]["lr"]

        epoch_msg = (
            f"Epoch {epoch}/{NUM_EPOCHS} | "
            f"Loss: {avg_loss:.4f} | "
            f"LR: {lr:.6f} | "
            f"Time: {time.time() - start_time:.2f}s"
        )
        print(epoch_msg)
        append_txt_log(txt_log, epoch_msg)
        append_csv_log(csv_log, epoch, avg_loss, lr)

        # Save best model checkpoint
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "src_char2idx": src_char2idx,
                    "tgt_char2idx": tgt_char2idx,
                },
                MODEL_PATH,
            )
            save_msg = f"BEST MODEL SAVED (epoch {epoch}, loss {avg_loss:.4f})"
            print(save_msg)
            append_txt_log(txt_log, save_msg)

    append_txt_log(txt_log, "\nTraining completed.\n")
"src_char2idx": src_char2idx, # "tgt_char2idx": tgt_char2idx, # }, # MODEL_PATH, # ) # save_msg = f"BEST MODEL SAVED (epoch {epoch}, loss {avg_loss:.4f})" # print(save_msg) # append_txt_log(txt_log, save_msg) # append_txt_log(txt_log, "\nTraining completed.\n") if __name__ == "__main__": main()