File size: 3,689 Bytes
beddc0e
 
 
 
 
 
 
dc29eb1
 
 
 
beddc0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# src/inference.py

import json
import re

import torch

from .config import MODEL_PATH, VOCAB_PATH, DEVICE
from .dataset import build_edge_index
from .preprocess import normalize_teluguish
from .model import GNNEncoder, Decoder, Seq2Seq


# Matches any code point in the Unicode "Telugu" block (U+0C00–U+0C7F).
TELUGU_RE = re.compile(r'[\u0C00-\u0C7F]')


def contains_telugu(s: str) -> bool:
    """Return True when *s* contains at least one Telugu-script character."""
    return TELUGU_RE.search(s) is not None


def load_model_and_vocab():
    """Load the saved vocabulary and model checkpoint.

    Reads the vocab JSON from VOCAB_PATH and the trained weights from
    MODEL_PATH, rebuilds the Seq2Seq model, moves it to DEVICE, and puts
    it in eval mode.

    Returns:
        (model, src_char2idx, tgt_char2idx, idx2tgt_char, word_map)
    """
    with open(VOCAB_PATH, "r", encoding="utf-8") as f:
        vocab = json.load(f)

    src_char2idx = vocab["src_char2idx"]
    tgt_char2idx = vocab["tgt_char2idx"]
    word_map = vocab["word_map"]

    ckpt = torch.load(MODEL_PATH, map_location=DEVICE)

    model = Seq2Seq(
        GNNEncoder(src_vocab_size=len(src_char2idx)),
        Decoder(tgt_vocab_size=len(tgt_char2idx)),
        tgt_sos_idx=tgt_char2idx["<sos>"],
        tgt_eos_idx=tgt_char2idx["<eos>"],
        tgt_pad_idx=tgt_char2idx["<pad>"],
    )
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(DEVICE)
    model.eval()

    # Inverse target mapping for decoding predicted ids back to characters.
    idx2tgt_char = {idx: ch for ch, idx in tgt_char2idx.items()}

    return model, src_char2idx, tgt_char2idx, idx2tgt_char, word_map


def encode_src_word_ids(word, src_char2idx, max_len=30):
    """Encode a word's characters as a 1-D LongTensor of source-vocab ids.

    Characters missing from *src_char2idx* map to the <unk> id, the
    sequence is truncated to *max_len*, and an empty word produces a
    single <unk> id so the result is never empty.
    """
    unk_id = src_char2idx["<unk>"]
    ids = [src_char2idx.get(ch, unk_id) for ch in word[:max_len]]
    return torch.tensor(ids or [unk_id], dtype=torch.long)


@torch.no_grad()
def transliterate_word_model(word, model, src_char2idx, idx2tgt_char, max_len=30):
    """Greedy-decode a single normalized Teluguish word into Telugu script.

    Builds a one-word graph batch, runs the GNN encoder, then decodes
    character-by-character until <eos> or *max_len* steps. Predicted ids
    missing from *idx2tgt_char* contribute the empty string.
    """
    node_ids = encode_src_word_ids(word, src_char2idx, max_len=max_len).to(DEVICE)  # [N]
    edges = build_edge_index(node_ids.size(0)).to(DEVICE)
    batch = torch.zeros(node_ids.size(0), dtype=torch.long, device=DEVICE)

    enc_out, src_mask = model.encoder(node_ids, edges, batch)  # [1,S,H], [1,S]
    n_batch = enc_out.size(0)
    hid = model.decoder.rnn.hidden_size

    # Zero-initialized LSTM state, matching training-time decoding.
    state_h = torch.zeros(1, n_batch, hid, device=DEVICE)
    state_c = torch.zeros(1, n_batch, hid, device=DEVICE)

    token = torch.tensor([model.tgt_sos_idx], device=DEVICE)
    pieces = []
    for _ in range(max_len):
        logits, state_h, state_c, _ = model.decoder(token, state_h, state_c, enc_out, src_mask)
        next_id = logits.argmax(dim=-1).item()
        if next_id == model.tgt_eos_idx:
            break
        pieces.append(idx2tgt_char.get(next_id, ""))
        token = torch.tensor([next_id], device=DEVICE)

    return "".join(pieces)


def transliterate_sentence(sentence: str, model, src_char2idx, idx2tgt_char, word_map):
    """Transliterate a whitespace-tokenized sentence word by word.

    Tokens that already contain Telugu script pass through unchanged.
    Otherwise the token is normalized and looked up in *word_map*; on a
    miss, the neural model decodes it, and an empty model output falls
    back to the original token.
    """
    result = []
    for token in sentence.strip().split():
        if contains_telugu(token):
            result.append(token)
            continue
        norm = normalize_teluguish(token)
        if norm in word_map:
            result.append(word_map[norm])
            continue
        predicted = transliterate_word_model(norm, model, src_char2idx, idx2tgt_char)
        result.append(predicted or token)
    return " ".join(result)


def demo():
    """Interactive loop: read Teluguish (or mixed) input, print Telugu.

    Empty input is skipped; 'q', 'quit', or 'exit' (case-insensitive)
    terminates the loop.
    """
    model, src_char2idx, tgt_char2idx, idx2tgt_char, word_map = load_model_and_vocab()

    while True:
        line = input("Teluguish (or mixed) > ").strip()
        if not line:
            continue
        if line.lower() in ("q", "quit", "exit"):
            break
        telugu = transliterate_sentence(line, model, src_char2idx, idx2tgt_char, word_map)
        print("TELUGU:", telugu)


# Launch the interactive demo when this module is executed as a script
# (via `python -m`, given the relative imports above).
if __name__ == "__main__":
    demo()