File size: 3,689 Bytes
beddc0e dc29eb1 beddc0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
# src/inference.py
import json
import re
import torch
from .config import MODEL_PATH, VOCAB_PATH, DEVICE
from .dataset import build_edge_index
from .preprocess import normalize_teluguish
from .model import GNNEncoder, Decoder, Seq2Seq
# Matches any single code point in the Unicode Telugu block (U+0C00-U+0C7F).
TELUGU_RE = re.compile(r'[\u0C00-\u0C7F]')


def contains_telugu(s: str) -> bool:
    """Report whether ``s`` holds at least one Telugu-block character."""
    first_hit = TELUGU_RE.search(s)
    return first_hit is not None
def load_model_and_vocab():
    """Load the vocabulary file and trained checkpoint and build the model.

    Reads the JSON at ``VOCAB_PATH`` for the source/target character tables
    and the word map, loads the checkpoint at ``MODEL_PATH`` onto ``DEVICE``,
    builds a ``Seq2Seq`` from a ``GNNEncoder`` and a ``Decoder`` sized to the
    two character tables, restores ``checkpoint["model_state_dict"]`` into
    it, moves it to ``DEVICE`` and puts it in eval mode.

    Returns:
        A tuple ``(model, src_char2idx, tgt_char2idx, idx2tgt_char,
        word_map)``, where ``idx2tgt_char`` is the inverse mapping of
        ``tgt_char2idx``.
    """
    with open(VOCAB_PATH, "r", encoding="utf-8") as vocab_file:
        vocab_data = json.load(vocab_file)

    src_char2idx = vocab_data["src_char2idx"]
    tgt_char2idx = vocab_data["tgt_char2idx"]
    word_map = vocab_data["word_map"]

    saved = torch.load(MODEL_PATH, map_location=DEVICE)

    # Encoder and decoder are sized from the vocabulary, not the checkpoint.
    model = Seq2Seq(
        GNNEncoder(src_vocab_size=len(src_char2idx)),
        Decoder(tgt_vocab_size=len(tgt_char2idx)),
        tgt_sos_idx=tgt_char2idx["<sos>"],
        tgt_eos_idx=tgt_char2idx["<eos>"],
        tgt_pad_idx=tgt_char2idx["<pad>"],
    )
    model.load_state_dict(saved["model_state_dict"])
    model.to(DEVICE)
    model.eval()

    # Reverse lookup (id -> character) used when decoding predictions.
    idx2tgt_char = dict(zip(tgt_char2idx.values(), tgt_char2idx.keys()))
    return model, src_char2idx, tgt_char2idx, idx2tgt_char, word_map
def encode_src_word_ids(word, src_char2idx, max_len=30):
    """Convert the characters of ``word`` into a 1-D ``torch.long`` id tensor.

    Characters that are not keys of ``src_char2idx`` map to the id stored
    under ``"<unk>"``. The id list is sliced with ``[:max_len]``; if that
    leaves nothing (for example an empty ``word``), a single ``"<unk>"`` id
    is used so the result is never empty.
    """
    unk_id = src_char2idx["<unk>"]
    lookup = src_char2idx.get
    char_ids = [lookup(ch, unk_id) for ch in word]
    # An empty slice is falsy, so fall back to one unknown-token id.
    clipped = char_ids[:max_len] or [unk_id]
    return torch.tensor(clipped, dtype=torch.long)
@torch.no_grad()
def transliterate_word_model(word, model, src_char2idx, idx2tgt_char, max_len=30):
    """Greedily decode one normalized Teluguish word into Telugu script.

    Args:
        word: Normalized Teluguish word.
        model: Object exposing ``encoder``, ``decoder``, ``tgt_sos_idx`` and
            ``tgt_eos_idx`` (the ``Seq2Seq`` built by this module).
        src_char2idx: Source character -> id mapping; must contain
            ``"<unk>"``.
        idx2tgt_char: Target id -> character mapping.
        max_len: Cap applied both to the number of encoded source
            characters and to the number of decoding steps.

    Returns:
        The decoded string. It is empty when the model predicts
        end-of-sequence immediately or produces only marker tokens, which
        lets callers fall back to the original word.
    """
    # Vocabulary marker tokens are bookkeeping, not text. If the model
    # predicts one mid-sequence it must not leak into the output as the
    # literal string "<pad>" / "<sos>" / ...
    marker_tokens = ("<sos>", "<eos>", "<pad>", "<unk>")

    x = encode_src_word_ids(word, src_char2idx, max_len=max_len)  # [N]
    num_nodes = x.size(0)
    edge_idx = build_edge_index(num_nodes).to(DEVICE)
    # A single word is one graph, so every node belongs to batch element 0.
    batch_vec = torch.zeros(num_nodes, dtype=torch.long, device=DEVICE)
    x = x.to(DEVICE)

    encoder_outputs, src_mask = model.encoder(x, edge_idx, batch_vec)  # [1,S,H], [1,S]
    batch_size = encoder_outputs.size(0)

    # Decoder state starts from zeros, with a leading dimension of 1.
    hidden_size = model.decoder.rnn.hidden_size
    hidden = torch.zeros(1, batch_size, hidden_size, device=DEVICE)
    cell = torch.zeros(1, batch_size, hidden_size, device=DEVICE)

    sos_id = model.tgt_sos_idx
    eos_id = model.tgt_eos_idx

    input_token = torch.tensor([sos_id], device=DEVICE)
    out_chars = []
    for _step in range(max_len):
        logits, hidden, cell, _extra = model.decoder(
            input_token, hidden, cell, encoder_outputs, src_mask
        )
        pred_id = logits.argmax(dim=-1).item()
        if pred_id == eos_id:
            break
        ch = idx2tgt_char.get(pred_id, "")
        # Keep feeding the prediction back to the decoder either way, but
        # only emit it when it is a real character.
        if pred_id != sos_id and ch not in marker_tokens:
            out_chars.append(ch)
        input_token = torch.tensor([pred_id], device=DEVICE)

    return "".join(out_chars)
def transliterate_sentence(sentence: str, model, src_char2idx, idx2tgt_char, word_map):
    """Transliterate a whitespace-separated sentence token by token.

    Each token is handled independently:

    * a token that already contains Telugu characters is kept as-is;
    * otherwise it is normalized with ``normalize_teluguish`` and, if the
      normalized form is a key of ``word_map``, replaced by the mapped value;
    * otherwise the model decodes it, and the original token is kept when
      the model returns an empty string.

    The results are joined with single spaces, so runs of whitespace in the
    input collapse.
    """

    def _convert(token):
        if contains_telugu(token):
            return token
        normalized = normalize_teluguish(token)
        if normalized in word_map:
            return word_map[normalized]
        decoded = transliterate_word_model(normalized, model, src_char2idx, idx2tgt_char)
        # Empty model output falls back to the untouched input token.
        return decoded or token

    return " ".join([_convert(token) for token in sentence.strip().split()])
def demo():
    """Run an interactive transliteration prompt on stdin/stdout.

    Loads the model and vocabulary once, then repeatedly reads a line:
    blank lines are ignored, ``q`` / ``quit`` / ``exit`` (case-insensitive)
    ends the session, and anything else is transliterated and printed.
    End of input (Ctrl-D or an exhausted pipe) and Ctrl-C also end the
    session cleanly instead of surfacing a traceback.
    """
    # The forward target table is returned by the loader but not needed here.
    model, src_char2idx, _tgt_char2idx, idx2tgt_char, word_map = load_model_and_vocab()
    while True:
        try:
            s = input("Teluguish (or mixed) > ").strip()
        except (EOFError, KeyboardInterrupt):
            # Finish the prompt line so the shell prompt starts cleanly.
            print()
            break
        if not s:
            continue
        if s.lower() in ("q", "quit", "exit"):
            break
        out = transliterate_sentence(s, model, src_char2idx, idx2tgt_char, word_map)
        print("TELUGU:", out)
# Start the interactive prompt only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo()
|