# src/inference.py
import json
import re

import torch

from .config import MODEL_PATH, VOCAB_PATH, DEVICE
from .dataset import build_edge_index
from .preprocess import normalize_teluguish
from .model import GNNEncoder, Decoder, Seq2Seq

# Telugu Unicode block (U+0C00-U+0C7F).
TELUGU_RE = re.compile(r'[\u0C00-\u0C7F]')


def contains_telugu(s: str) -> bool:
    return bool(TELUGU_RE.search(s))


def load_model_and_vocab():
    with open(VOCAB_PATH, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    src_char2idx = vocab["src_char2idx"]
    tgt_char2idx = vocab["tgt_char2idx"]
    word_map = vocab["word_map"]

    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    encoder = GNNEncoder(src_vocab_size=len(src_char2idx))
    decoder = Decoder(tgt_vocab_size=len(tgt_char2idx))
    model = Seq2Seq(
        encoder,
        decoder,
        tgt_sos_idx=tgt_char2idx["<sos>"],
        tgt_eos_idx=tgt_char2idx["<eos>"],
        tgt_pad_idx=tgt_char2idx["<pad>"],
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(DEVICE)
    model.eval()

    idx2tgt_char = {v: k for k, v in tgt_char2idx.items()}
    return model, src_char2idx, tgt_char2idx, idx2tgt_char, word_map


def encode_src_word_ids(word, src_char2idx, max_len=30):
    SRC_UNK = src_char2idx["<unk>"]
    ids = [src_char2idx.get(c, SRC_UNK) for c in word][:max_len]
    if not ids:
        ids = [SRC_UNK]
    return torch.tensor(ids, dtype=torch.long)


@torch.no_grad()
def transliterate_word_model(word, model, src_char2idx, idx2tgt_char, max_len=30):
    """
    word: normalized Teluguish (romanized Telugu)
    returns: Telugu script string, greedily decoded up to max_len characters
    """
    # Build a single-word graph: node ids, chain edges, and a batch vector
    # assigning every node to graph 0.
    x = encode_src_word_ids(word, src_char2idx, max_len=max_len)  # [N]
    edge_idx = build_edge_index(x.size(0))
    batch_vec = torch.zeros(x.size(0), dtype=torch.long)

    x = x.to(DEVICE)
    edge_idx = edge_idx.to(DEVICE)
    batch_vec = batch_vec.to(DEVICE)

    encoder_outputs, src_mask = model.encoder(x, edge_idx, batch_vec)  # [1,S,H], [1,S]
    B, S, H = encoder_outputs.size()
    hidden = torch.zeros(1, B, model.decoder.rnn.hidden_size, device=DEVICE)
    cell = torch.zeros(1, B, model.decoder.rnn.hidden_size, device=DEVICE)

    TGT_SOS = model.tgt_sos_idx
    TGT_EOS = model.tgt_eos_idx

    # Greedy decoding: feed the argmax prediction back in until <eos> or max_len.
    input_token = torch.tensor([TGT_SOS], device=DEVICE)
    out_chars = []
    for _ in range(max_len):
        logits, hidden, cell, _ = model.decoder(
            input_token, hidden, cell, encoder_outputs, src_mask
        )
        pred_id = logits.argmax(dim=-1).item()
        if pred_id == TGT_EOS:
            break
        out_chars.append(idx2tgt_char.get(pred_id, ""))
        input_token = torch.tensor([pred_id], device=DEVICE)

    return "".join(out_chars)


def transliterate_sentence(sentence: str, model, src_char2idx, idx2tgt_char, word_map):
    tokens = sentence.strip().split()
    out_tokens = []
    for w in tokens:
        # Tokens already in Telugu script pass through unchanged.
        if contains_telugu(w):
            out_tokens.append(w)
            continue
        w_norm = normalize_teluguish(w)
        # Dictionary lookup first; fall back to the model for unseen words,
        # and to the original token if the model produces nothing.
        if w_norm in word_map:
            out_tokens.append(word_map[w_norm])
        else:
            tel = transliterate_word_model(w_norm, model, src_char2idx, idx2tgt_char)
            out_tokens.append(tel if tel else w)
    return " ".join(out_tokens)


def demo():
    model, src_char2idx, tgt_char2idx, idx2tgt_char, word_map = load_model_and_vocab()
    while True:
        s = input("Teluguish (or mixed) > ").strip()
        if not s:
            continue
        if s.lower() in ("q", "quit", "exit"):
            break
        out = transliterate_sentence(s, model, src_char2idx, idx2tgt_char, word_map)
        print("TELUGU:", out)


if __name__ == "__main__":
    demo()
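
# Note on the character graph: build_edge_index is defined in src/dataset.py.
# Purely as an illustration of what transliterate_word_model feeds the GNN
# encoder, a plausible sketch, ASSUMING the graph is a bidirectional chain
# over adjacent character positions (the actual definition may differ):
#
#     def build_edge_index(n: int) -> torch.Tensor:
#         # Edges i->i+1 and i+1->i for a word of n characters; shape [2, E].
#         src = list(range(n - 1)) + list(range(1, n))
#         dst = list(range(1, n)) + list(range(n - 1))
#         return torch.tensor([src, dst], dtype=torch.long)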
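
# Example usage (a minimal sketch): calling the pipeline non-interactively
# instead of via demo(). Assumes a trained checkpoint at MODEL_PATH and a
# vocab file at VOCAB_PATH; the input string below is hypothetical.
#
#     from src.inference import load_model_and_vocab, transliterate_sentence
#
#     model, src_char2idx, _, idx2tgt_char, word_map = load_model_and_vocab()
#     out = transliterate_sentence(
#         "nenu repu vastanu",  # hypothetical Teluguish input
#         model, src_char2idx, idx2tgt_char, word_map,
#     )
#     print(out)  # Telugu-script rendering of the sentence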