import gradio as gr
import torch
import re
import setup
from model import create_model
from dataset import TeluguishDataset, Vocabulary  # Required for loading the saved vocab objects
from torch_geometric.data import Data

# --- 1. GLOBAL SETUP: Load the model and vocabularies only ONCE when the app starts ---
print("--- Initializing Gradio Application ---")

# Set the device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Load the vocabularies from the dataset's processed file
print("Loading vocabularies...")
DATASET = TeluguishDataset(root='.')
SRC_VOCAB = DATASET.src_vocab
TGT_VOCAB = DATASET.tgt_vocab
print("Vocabularies loaded.")

# Load the trained model
print("Loading trained model...")
# Create the model structure with the correct vocabulary sizes
MODEL = create_model(len(SRC_VOCAB), len(TGT_VOCAB))
# Load the saved weights into the model structure
MODEL.load_state_dict(torch.load(setup.MODEL_PATH, map_location=DEVICE))
MODEL.to(DEVICE)
MODEL.eval()  # Set the model to evaluation mode permanently
print("Model loaded. Application is ready to launch.")
# -----------------------------------------------------------------------------------------

# Input preprocessing patterns, compiled once at import time.
_DISALLOWED_CHARS = re.compile(r'[^a-z\s]')  # anything that is not a-z or whitespace
_WHITESPACE_RUN = re.compile(r'\s+')


# --- 2. CORE LOGIC: The helper functions needed for inference ---
def create_graph_from_text(text, vocab):
    """
    Convert a single Teluguish string into a PyG ``Data`` object for inference.

    Every character becomes one node whose feature is its index in
    ``vocab.stoi``; characters missing from the vocabulary fall back to
    ``setup.PAD_TOKEN``. Neighbouring characters are linked by edges in both
    directions, forming a chain graph.

    Args:
        text: The (already preprocessed) input string.
        vocab: Source vocabulary exposing a ``stoi`` mapping.

    Returns:
        A ``Data`` object with ``x`` of shape ``(len(text),)`` and
        ``edge_index`` of shape ``(2, num_edges)``. For text shorter than two
        characters ``edge_index`` has shape ``(2, 0)``.
    """
    tokens = list(text)
    node_features = torch.tensor(
        [vocab.stoi.get(token, setup.PAD_TOKEN) for token in tokens],
        dtype=torch.long,
    )

    edge_list = []
    for i in range(len(tokens) - 1):
        edge_list.extend([[i, i + 1], [i + 1, i]])  # Bidirectional edges

    # view(-1, 2) keeps the tensor 2-D even when edge_list is empty, so the
    # transpose yields shape (2, 0). Without it, torch.tensor([]).t() stays
    # 1-D with shape (0,), which breaks code that indexes edge_index[0]
    # (e.g. a single-character input).
    edge_index = torch.tensor(edge_list, dtype=torch.long).view(-1, 2).t().contiguous()

    return Data(x=node_features, edge_index=edge_index)


def transliterate_sentence(model, sentence, src_vocab, tgt_vocab, device, max_len=100):
    """
    Transliterate a single Teluguish sentence with greedy decoding.

    The input is lower-cased and every character other than a-z/whitespace is
    removed. Whitespace runs (including newlines from a multi-line textbox)
    are collapsed to a single space and the ends are trimmed. The cleaned
    text is encoded as a character graph by ``model.encoder``; then
    ``model.decoder`` is run one step at a time, starting from
    ``setup.SOS_TOKEN``, feeding back its argmax prediction. Decoding stops
    at ``setup.EOS_TOKEN`` or after ``max_len`` steps.

    Args:
        model: Trained model exposing ``encoder`` and ``decoder`` submodules.
        sentence: Raw user input in Roman script.
        src_vocab: Source vocabulary (``stoi`` mapping) for the input graph.
        tgt_vocab: Target vocabulary (``itos`` lookup) for the output text.
        device: Torch device the model lives on.
        max_len: Maximum number of decoder steps.

    Returns:
        The predicted Telugu-script string, or ``""`` if nothing remains
        after preprocessing. The EOS token is never part of the output; if
        decoding stops at ``max_len`` without EOS, every predicted token is
        kept.
    """
    # Filter first, then trim, so removed characters cannot leave stray
    # leading/trailing whitespace behind.
    sentence = _DISALLOWED_CHARS.sub('', sentence.lower())
    sentence = _WHITESPACE_RUN.sub(' ', sentence).strip()
    if not sentence:
        return ""

    # Create a graph from the input sentence
    src_graph = create_graph_from_text(sentence, src_vocab).to(device)
    # All nodes belong to graph 0: this is a batch of one.
    src_batch_vector = torch.zeros(src_graph.x.shape[0], dtype=torch.long, device=device)

    predicted = []
    prev_token = setup.SOS_TOKEN
    with torch.no_grad():
        _, graph_embedding = model.encoder(src_graph.x, src_graph.edge_index, src_batch_vector)

        # NOTE(review): assumes graph_embedding is (1, HIDDEN_DIM) for a
        # single graph, giving an initial decoder state of (1, 1, HIDDEN_DIM)
        # — confirm against the encoder.
        hidden = graph_embedding.unsqueeze(0)
        cell = torch.zeros(1, 1, setup.HIDDEN_DIM, device=device)

        for _ in range(max_len):
            trg_tensor = torch.tensor([prev_token], dtype=torch.long, device=device)
            output, hidden, cell = model.decoder(trg_tensor, hidden, cell, graph_embedding)
            prev_token = output.argmax(1).item()
            if prev_token == setup.EOS_TOKEN:
                break
            # Collect only real predictions so hitting max_len without EOS
            # does not silently drop the final character.
            predicted.append(prev_token)

    return "".join(tgt_vocab.itos[i] for i in predicted)


# --- 3. GRADIO FUNCTION: The main function that the UI will call ---
def predict(teluguish_text):
    """
    Prediction entry point for the Gradio interface.

    Uses the module-level MODEL, SRC_VOCAB, TGT_VOCAB and DEVICE loaded at
    startup. Empty input returns an empty string without running the model.
    """
    if not teluguish_text:
        return ""

    # Call the core logic function to get the transliteration
    return transliterate_sentence(MODEL, teluguish_text, SRC_VOCAB, TGT_VOCAB, DEVICE)


# --- 4. GRADIO INTERFACE: Define the web UI layout ---
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        lines=3,
        label="Teluguish Input (Roman Script)",
        placeholder="Type a sentence here... for example, 'namaste andaru bagunnara'"
    ),
    outputs=gr.Textbox(
        label="Telugu Script Output"
    ),
    title="Teluguish to Telugu Transliteration with GNN ",
    description="This app uses a Graph Neural Network (GNN) to convert Telugu written in the Roman script (also known as 'Teluguish') into the authentic Telugu script. ",
    allow_flagging="never"
)

# --- 5. LAUNCH THE APP ---
if __name__ == "__main__":
    iface.launch()