import gradio as gr
import torch
import re
import setup  # Project configuration: paths, special-token indices, and model hyperparameters
from model import create_model
from dataset import TeluguishDataset, Vocabulary  # Required for loading the saved vocab objects
from torch_geometric.data import Data

# --- 1. GLOBAL SETUP: Load the model and vocabularies only ONCE when the app starts ---
print("--- Initializing Gradio Application ---")

# Set the device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Load the vocabularies from the dataset's processed file
print("Loading vocabularies...")
DATASET = TeluguishDataset(root='.')
SRC_VOCAB = DATASET.src_vocab
TGT_VOCAB = DATASET.tgt_vocab
print("Vocabularies loaded.")

# Load the trained model
print("Loading trained model...")
# Create the model structure with the correct vocabulary sizes
MODEL = create_model(len(SRC_VOCAB), len(TGT_VOCAB))
# Load the saved weights into the model structure
MODEL.load_state_dict(torch.load(setup.MODEL_PATH, map_location=DEVICE))
MODEL.to(DEVICE)
MODEL.eval()  # Set the model to evaluation mode permanently
print("Model loaded. Application is ready to launch.")
# -----------------------------------------------------------------------------------------


# --- 2. CORE LOGIC: The helper functions needed for inference ---

def create_graph_from_text(text, vocab):
    """
    Converts a single Teluguish string into a PyG Data object for inference.
    """
    tokens = list(text)
    # Map each character to its vocabulary index; characters unseen during training fall back to the padding index
    node_features = torch.tensor([vocab.stoi.get(token, setup.PAD_TOKEN) for token in tokens], dtype=torch.long)
    
    edge_list = []
    if len(tokens) > 1:
        for i in range(len(tokens) - 1):
            edge_list.extend([[i, i + 1], [i + 1, i]])  # Bidirectional edges between adjacent characters

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        # A single-character input has no edges; edge_index must still have shape [2, 0]
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return Data(x=node_features, edge_index=edge_index)
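# Illustration (not executed): for the input "abc", node_features holds the three character
# indices and edge_index is [[0, 1, 1, 2], [1, 0, 2, 1]], i.e. every pair of adjacent
# characters is connected in both directions.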

def transliterate_sentence(model, sentence, src_vocab, tgt_vocab, device, max_len=100):
    """
    Performs the transliteration for a single sentence.
    """
    # Preprocess the input sentence
    sentence = sentence.lower().strip()
    sentence = re.sub(r'[^a-z\s]', '', sentence)
    
    if not sentence:
        return ""

    # Create a graph from the input sentence
    src_graph = create_graph_from_text(sentence, src_vocab)
    src_graph = src_graph.to(device)
    
    # All nodes belong to the same (single) graph, so the batch vector is all zeros
    src_batch_vector = torch.zeros(src_graph.x.shape[0], dtype=torch.long).to(device)

    with torch.no_grad():
        # Encode the character graph; the pooled graph embedding initialises the decoder's hidden state
        _, graph_embedding = model.encoder(src_graph.x, src_graph.edge_index, src_batch_vector)
        hidden = graph_embedding.unsqueeze(0)
        cell = torch.zeros(1, 1, setup.HIDDEN_DIM).to(device)

    # Greedy decoding: start with <sos> and repeatedly feed the last prediction back in
    trg_indexes = [setup.SOS_TOKEN]

    for _ in range(max_len):
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)
        with torch.no_grad():
            output, hidden, cell = model.decoder(trg_tensor, hidden, cell, graph_embedding)

        pred_token = output.argmax(1).item()
        trg_indexes.append(pred_token)

        # Stop as soon as the model emits the end-of-sentence token
        if pred_token == setup.EOS_TOKEN:
            break
            
    # Drop the leading <sos>; drop the final <eos> only if decoding actually stopped on it
    if trg_indexes[-1] == setup.EOS_TOKEN:
        trg_indexes = trg_indexes[:-1]
    return "".join(tgt_vocab.itos[i] for i in trg_indexes[1:])


# --- 3. GRADIO FUNCTION: The main function that the UI will call ---

def predict(teluguish_text):
    """
    The main prediction function for the Gradio interface.
    It uses the globally loaded model and vocabularies.
    """
    if not teluguish_text:
        return ""
    # Call the core logic function to get the transliteration
    return transliterate_sentence(MODEL, teluguish_text, SRC_VOCAB, TGT_VOCAB, DEVICE)
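# Quick local sanity check (assumes the trained weights at setup.MODEL_PATH and the processed
# dataset are available):
#     print(predict("namaste andaru bagunnara"))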


# --- 4. GRADIO INTERFACE: Define the web UI layout ---

iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        lines=3,
        label="Teluguish Input (Roman Script)",
        placeholder="Type a sentence here... for example, 'namaste andaru bagunnara'"
    ),
    outputs=gr.Textbox(
        label="Telugu Script Output"
    ),
    title="Teluguish to Telugu Transliteration with GNN ",
    description="This app uses a Graph Neural Network (GNN) to convert Telugu written in the Roman script (also known as 'Teluguish') into the authentic Telugu script. ",
    allow_flagging="never"
)

# --- 5. LAUNCH THE APP ---
if __name__ == "__main__":
    iface.launch()
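    # Optional: iface.launch(share=True) creates a temporary public URL, which can be
    # useful when the app runs on a remote machine without an exposed port.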