# Transliteration / app.py
# Anudeep Tippabathuni
# Last commit: 8499f35 ("updated app.py")
import gradio as gr
import torch
import re
import setup
from model import create_model
from dataset import TeluguishDataset, Vocabulary # Required for loading the saved vocab objects
from torch_geometric.data import Data
# --- 1. GLOBAL SETUP: Load the model and vocabularies only ONCE when the app starts ---
# Everything in this section runs at import time, so every Gradio request
# reuses the same device, vocabulary and model objects instead of reloading
# them per call.
print("--- Initializing Gradio Application ---")

# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# The dataset object is only kept around for the vocabularies it carries.
# NOTE(review): presumably it reads an already-processed file under `root`
# rather than rebuilding from raw data — confirm in dataset.py.
print("Loading vocabularies...")
DATASET = TeluguishDataset(root=".")
SRC_VOCAB, TGT_VOCAB = DATASET.src_vocab, DATASET.tgt_vocab
print("Vocabularies loaded.")

print("Loading trained model...")
# Build the architecture sized to both vocabularies, then restore the trained
# weights; map_location lets a GPU-trained checkpoint load on a CPU-only host.
MODEL = create_model(len(SRC_VOCAB), len(TGT_VOCAB))
MODEL.load_state_dict(torch.load(setup.MODEL_PATH, map_location=DEVICE))
# Move to the target device and switch to inference mode (affects layers such
# as dropout / batch-norm) for the lifetime of the app.
MODEL.to(DEVICE).eval()
print("Model loaded. Application is ready to launch.")
# -----------------------------------------------------------------------------------------
# --- 2. CORE LOGIC: The helper functions needed for inference ---
def create_graph_from_text(text, vocab):
    """
    Convert a single Teluguish string into a PyG ``Data`` object for inference.

    Each character becomes one node whose feature is its source-vocabulary
    index. Each pair of neighbouring characters is linked by two directed
    edges (i -> i+1 and i+1 -> i), giving a bidirectional chain.

    Args:
        text: The (already cleaned) input string.
        vocab: Source vocabulary exposing a ``stoi`` mapping with ``.get``.
            Characters missing from ``stoi`` are mapped to ``setup.PAD_TOKEN``.

    Returns:
        ``torch_geometric.data.Data`` with ``x`` holding one long index per
        character and ``edge_index`` of shape ``[2, 2 * (len(text) - 1)]``.
        For inputs shorter than two characters ``edge_index`` is an empty
        ``[2, 0]`` long tensor.
    """
    tokens = list(text)
    node_features = torch.tensor(
        [vocab.stoi.get(token, setup.PAD_TOKEN) for token in tokens],
        dtype=torch.long,
    )

    # Forward and backward edges are interleaved per neighbouring pair, keeping
    # the same edge order the model has always been fed.
    edge_list = []
    for left, right in zip(range(len(tokens)), range(1, len(tokens))):
        edge_list.extend(([left, right], [right, left]))

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]) is 1-D with shape [0], and .t() leaves it that way;
        # PyG expects edge_index to be [2, num_edges] even with no edges, so a
        # single-character input would otherwise produce a malformed graph.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return Data(x=node_features, edge_index=edge_index)
def transliterate_sentence(model, sentence, src_vocab, tgt_vocab, device, max_len=100):
    """
    Transliterate one Teluguish sentence into Telugu script by greedy decoding.

    Preprocessing steps:
    - Lower-case the input.
    - Remove every character other than ``a-z`` and whitespace.
    - Collapse whitespace runs to single spaces and trim the ends.

    The cleaned string is encoded as a character graph by
    ``create_graph_from_text``, and the encoder's graph embedding seeds the
    decoder. The decoder then runs one step at a time, feeding back its arg-max
    prediction, until it emits ``setup.EOS_TOKEN`` or ``max_len`` steps have run.

    Args:
        model: Trained model exposing ``encoder`` and ``decoder`` sub-modules.
        sentence: Raw user input in Roman script.
        src_vocab: Source vocabulary handed to ``create_graph_from_text``.
        tgt_vocab: Target vocabulary whose ``itos`` maps indices to strings.
        device: torch device the model lives on.
        max_len: Maximum number of decoding steps.

    Returns:
        The transliterated string without start/end markers, or ``""`` when
        nothing is left after cleaning.
    """
    # Filter *before* trimming: otherwise input such as "123 !!" leaves a lone
    # " " that passes the emptiness check. Collapsing whitespace also turns
    # newlines/tabs from the multi-line textbox into plain spaces.
    sentence = re.sub(r'[^a-z\s]', '', sentence.lower())
    sentence = " ".join(sentence.split())
    if not sentence:
        return ""

    src_graph = create_graph_from_text(sentence, src_vocab).to(device)
    # All nodes belong to graph 0: the "batch" holds exactly one graph.
    src_batch_vector = torch.zeros(src_graph.x.shape[0], dtype=torch.long, device=device)

    with torch.no_grad():
        _, graph_embedding = model.encoder(src_graph.x, src_graph.edge_index, src_batch_vector)
        # NOTE(review): the leading axis presumably makes this the decoder's
        # (num_layers=1, batch, hidden) initial state — confirm in model.py.
        hidden = graph_embedding.unsqueeze(0)
        cell = torch.zeros(1, 1, setup.HIDDEN_DIM, device=device)

        trg_indexes = [setup.SOS_TOKEN]
        for _ in range(max_len):
            trg_tensor = torch.tensor([trg_indexes[-1]], dtype=torch.long, device=device)
            output, hidden, cell = model.decoder(trg_tensor, hidden, cell, graph_embedding)
            pred_token = output.argmax(1).item()
            trg_indexes.append(pred_token)
            if pred_token == setup.EOS_TOKEN:
                break

    # Drop the leading SOS, and drop the trailing EOS only if one was produced.
    # When max_len is hit first, the last index is genuine output, so a blind
    # [1:-1] slice would silently lose a character.
    output_indexes = trg_indexes[1:]
    if output_indexes and output_indexes[-1] == setup.EOS_TOKEN:
        output_indexes.pop()
    return "".join(tgt_vocab.itos[i] for i in output_indexes)
# --- 3. GRADIO FUNCTION: The main function that the UI will call ---
def predict(teluguish_text):
    """
    Gradio callback: transliterate the textbox contents into Telugu script.

    Falsy input (``None`` or ``""``) short-circuits to ``""`` without touching
    the model. Anything else is passed to ``transliterate_sentence`` together
    with the module-level MODEL, SRC_VOCAB, TGT_VOCAB and DEVICE that were
    loaded once at start-up.
    """
    has_text = bool(teluguish_text)
    if has_text:
        return transliterate_sentence(
            MODEL, teluguish_text, SRC_VOCAB, TGT_VOCAB, DEVICE
        )
    return ""
# --- 4. GRADIO INTERFACE: Define the web UI layout ---
# Input/output widgets are built up front so the Interface call below reads
# as a plain description of the app.
_teluguish_input = gr.Textbox(
    lines=3,
    label="Teluguish Input (Roman Script)",
    placeholder="Type a sentence here... for example, 'namaste andaru bagunnara'",
)
_telugu_output = gr.Textbox(label="Telugu Script Output")

iface = gr.Interface(
    fn=predict,
    inputs=_teluguish_input,
    outputs=_telugu_output,
    title="Teluguish to Telugu Transliteration with GNN ",
    description=(
        "This app uses a Graph Neural Network (GNN) to convert Telugu written in "
        "the Roman script (also known as 'Teluguish') into the authentic Telugu "
        "script. "
    ),
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode` — confirm against the pinned Gradio version.
    allow_flagging="never",
)

# --- 5. LAUNCH THE APP ---
# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()