"""Inference helpers for Teluguish -> Telugu transliteration.

Loads the trained GNN-encoder / recurrent-decoder checkpoint and exposes
word-level and sentence-level transliteration plus an interactive demo.
"""

import json
import re

import torch

from .config import MODEL_PATH, VOCAB_PATH, DEVICE
from .dataset import build_edge_index
from .model import GNNEncoder, Decoder, Seq2Seq
from .preprocess import normalize_teluguish

# Any code point in the Telugu Unicode block (U+0C00-U+0C7F).
TELUGU_RE = re.compile(r'[\u0C00-\u0C7F]')


def contains_telugu(s: str) -> bool:
    """Return True if *s* contains at least one Telugu-script character."""
    return bool(TELUGU_RE.search(s))
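
# Quick sanity check: contains_telugu("నమస్తే") is True,
# contains_telugu("namaste") is False.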


def load_model_and_vocab():
    """Load the vocab JSON and trained checkpoint; return the model in eval mode."""
    with open(VOCAB_PATH, "r", encoding="utf-8") as f:
        vocab = json.load(f)

    src_char2idx = vocab["src_char2idx"]
    tgt_char2idx = vocab["tgt_char2idx"]
    word_map = vocab["word_map"]

    # map_location lets a GPU-trained checkpoint load on CPU-only hosts.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

    encoder = GNNEncoder(src_vocab_size=len(src_char2idx))
    decoder = Decoder(tgt_vocab_size=len(tgt_char2idx))
    model = Seq2Seq(encoder, decoder,
                    tgt_sos_idx=tgt_char2idx["<sos>"],
                    tgt_eos_idx=tgt_char2idx["<eos>"],
                    tgt_pad_idx=tgt_char2idx["<pad>"])
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(DEVICE)
    model.eval()

    # Inverse map for turning predicted ids back into characters.
    idx2tgt_char = {v: k for k, v in tgt_char2idx.items()}

    return model, src_char2idx, tgt_char2idx, idx2tgt_char, word_map


def encode_src_word_ids(word, src_char2idx, max_len=30):
    """Encode a word as source char ids, truncated to *max_len*."""
    SRC_UNK = src_char2idx["<unk>"]
    ids = [src_char2idx.get(c, SRC_UNK) for c in word][:max_len]
    if not ids:
        # Empty input (e.g. a token that normalized away): keep a single
        # <unk> node so the encoder never sees an empty tensor.
        ids = [SRC_UNK]
    return torch.tensor(ids, dtype=torch.long)
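
# For example, encode_src_word_ids("amma", src_char2idx) yields a length-4
# LongTensor of char ids; characters missing from the vocab map to <unk>.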


@torch.no_grad()
def transliterate_word_model(word, model, src_char2idx, idx2tgt_char, max_len=30):
    """Transliterate one normalized Teluguish word to a Telugu-script string."""
    # Build the single-word graph input: char-id nodes, the edge index from
    # build_edge_index, and a batch vector assigning every node to graph 0.
    x = encode_src_word_ids(word, src_char2idx, max_len=max_len)
    edge_idx = build_edge_index(x.size(0))
    batch_vec = torch.zeros(x.size(0), dtype=torch.long)

    x = x.to(DEVICE)
    edge_idx = edge_idx.to(DEVICE)
    batch_vec = batch_vec.to(DEVICE)

    encoder_outputs, src_mask = model.encoder(x, edge_idx, batch_vec)
    B = encoder_outputs.size(0)  # batch size (1 for a single word)

    # The decoder's recurrent state starts from zeros.
    hidden = torch.zeros(1, B, model.decoder.rnn.hidden_size, device=DEVICE)
    cell = torch.zeros(1, B, model.decoder.rnn.hidden_size, device=DEVICE)

    TGT_SOS = model.tgt_sos_idx
    TGT_EOS = model.tgt_eos_idx

    input_token = torch.tensor([TGT_SOS], device=DEVICE)

    # Greedy decoding: feed the argmax token back in until <eos> or max_len.
    out_chars = []
    for _ in range(max_len):
        logits, hidden, cell, _ = model.decoder(input_token, hidden, cell, encoder_outputs, src_mask)
        pred_id = logits.argmax(dim=-1).item()
        if pred_id == TGT_EOS:
            break
        out_chars.append(idx2tgt_char.get(pred_id, ""))
        input_token = torch.tensor([pred_id], device=DEVICE)

    return "".join(out_chars)
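
# Note: decoding above is greedy (argmax at each step). A beam search over
# the same decoder interface would be a drop-in upgrade if per-word accuracy
# matters more than latency.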


def transliterate_sentence(sentence: str, model, src_char2idx, idx2tgt_char, word_map):
    """Transliterate a sentence word by word.

    Telugu-script words pass through unchanged, dictionary hits use the
    exact word map, and everything else falls back to the model.
    """
    tokens = sentence.strip().split()
    out_tokens = []
    for w in tokens:
        if contains_telugu(w):
            # Already in Telugu script: leave it alone.
            out_tokens.append(w)
            continue
        w_norm = normalize_teluguish(w)
        if w_norm in word_map:
            out_tokens.append(word_map[w_norm])
        else:
            tel = transliterate_word_model(w_norm, model, src_char2idx, idx2tgt_char)
            # If the model produced nothing, keep the original token.
            out_tokens.append(tel if tel else w)
    return " ".join(out_tokens)
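
# Typical use (assumes a trained checkpoint and vocab file on disk; the
# input sentence here is an illustrative Teluguish example):
#   model, src_c2i, _, i2t, wmap = load_model_and_vocab()
#   transliterate_sentence("nenu intiki veltunnanu", model, src_c2i, i2t, wmap)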


def demo():
    """Interactive loop: read Teluguish (or mixed) text, print Telugu."""
    model, src_char2idx, tgt_char2idx, idx2tgt_char, word_map = load_model_and_vocab()

    while True:
        s = input("Teluguish (or mixed) > ").strip()
        if not s:
            continue
        if s.lower() in ("q", "quit", "exit"):
            break
        out = transliterate_sentence(s, model, src_char2idx, idx2tgt_char, word_map)
        print("TELUGU:", out)


if __name__ == "__main__":
    demo()