Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import re | |
| import setup | |
| from model import create_model | |
| from dataset import TeluguishDataset, Vocabulary # Required for loading the saved vocab objects | |
| from torch_geometric.data import Data | |
# --- 1. GLOBAL SETUP: Load the model and vocabularies only ONCE when the app starts ---
print("--- Initializing Gradio Application ---")

# Pick the GPU when PyTorch reports one as available, otherwise stay on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

# The source/target vocabularies are read off the dataset object, which is
# built with the current directory as its root.
print("Loading vocabularies...")
DATASET = TeluguishDataset(root='.')
SRC_VOCAB, TGT_VOCAB = DATASET.src_vocab, DATASET.tgt_vocab
print("Vocabularies loaded.")

# Build the network with layer sizes matching the vocabularies, then restore
# the saved weights onto the selected device.
print("Loading trained model...")
MODEL = create_model(len(SRC_VOCAB), len(TGT_VOCAB))
# NOTE(review): torch.load unpickles the checkpoint file; setup.MODEL_PATH
# should only ever point at a trusted file.
MODEL.load_state_dict(
    torch.load(setup.MODEL_PATH, map_location=DEVICE)
)
MODEL.to(DEVICE)
# The app only ever runs inference, so evaluation mode is set once here.
MODEL.eval()
print("Model loaded. Application is ready to launch.")
# -----------------------------------------------------------------------------------------
# --- 2. CORE LOGIC: The helper functions needed for inference ---
def create_graph_from_text(text, vocab):
    """
    Converts a single Teluguish string into a PyG Data object for inference.

    Every character of ``text`` becomes one node whose feature is that
    character's index in ``vocab.stoi``; characters missing from the mapping
    fall back to ``setup.PAD_TOKEN``. Neighbouring characters are connected
    by one edge in each direction, so the graph is a bidirectional chain.

    Args:
        text: The string to convert; it is split into individual characters.
        vocab: Object exposing a ``stoi`` mapping from character to index.

    Returns:
        A ``Data`` object whose ``x`` holds one long index per character and
        whose ``edge_index`` is a long tensor of shape ``[2, num_edges]``.
        With fewer than two characters there are no edges and ``edge_index``
        has shape ``[2, 0]``.
    """
    tokens = list(text)
    node_features = torch.tensor(
        [vocab.stoi.get(token, setup.PAD_TOKEN) for token in tokens],
        dtype=torch.long,
    )

    # Link each character to its successor and back again (bidirectional edges).
    edge_list = []
    for i in range(len(tokens) - 1):
        edge_list.extend([[i, i + 1], [i + 1, i]])

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        # Transposing an empty tensor would yield shape [0]; build the
        # edge-less index explicitly so it keeps the [2, num_edges] layout.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return Data(x=node_features, edge_index=edge_index)
def transliterate_sentence(model, sentence, src_vocab, tgt_vocab, device, max_len=100):
    """
    Performs the transliteration for a single sentence.

    The sentence is lower-cased, reduced to the characters ``a``-``z`` and
    whitespace, converted to a graph with ``create_graph_from_text``, encoded
    with ``model.encoder`` and then decoded greedily (arg-max at every step)
    with ``model.decoder``. Decoding stops when ``setup.EOS_TOKEN`` is
    predicted or after ``max_len`` steps, whichever comes first.

    Args:
        model: Object exposing ``encoder`` and ``decoder`` callables.
        sentence: Raw user text in Roman script.
        src_vocab: Source vocabulary passed to ``create_graph_from_text``.
        tgt_vocab: Target vocabulary; ``itos`` maps an index to its token.
        device: Torch device the tensors are placed on.
        max_len: Maximum number of decoding steps.

    Returns:
        The predicted tokens joined into one string, without the start and
        end markers. An empty string is returned when nothing is left of the
        sentence after preprocessing.
    """
    # Filter before stripping: removing digits/punctuation at the edges can
    # expose whitespace, and a whitespace-only remainder must count as empty.
    sentence = re.sub(r'[^a-z\s]', '', sentence.lower()).strip()
    if not sentence:
        return ""

    # Create a graph from the input sentence.
    src_graph = create_graph_from_text(sentence, src_vocab).to(device)
    # A single graph is processed, so every node is assigned to batch index 0.
    src_batch_vector = torch.zeros(
        src_graph.x.shape[0], dtype=torch.long, device=device
    )

    predicted = []
    with torch.no_grad():
        _, graph_embedding = model.encoder(
            src_graph.x, src_graph.edge_index, src_batch_vector
        )
        # The graph embedding seeds the decoder's hidden state; the cell
        # state starts at zero.
        # NOTE(review): presumably an LSTM-style decoder expecting a leading
        # layer dimension -- confirm against the model definition.
        hidden = graph_embedding.unsqueeze(0)
        cell = torch.zeros(1, 1, setup.HIDDEN_DIM, device=device)

        prev_token = setup.SOS_TOKEN
        for _ in range(max_len):
            trg_tensor = torch.tensor([prev_token], dtype=torch.long, device=device)
            output, hidden, cell = model.decoder(
                trg_tensor, hidden, cell, graph_embedding
            )
            prev_token = output.argmax(1).item()
            if prev_token == setup.EOS_TOKEN:
                break
            # Only real output tokens are collected, so a run that hits
            # max_len without producing EOS keeps its final character.
            predicted.append(prev_token)

    return "".join(tgt_vocab.itos[i] for i in predicted)
# --- 3. GRADIO FUNCTION: The main function that the UI will call ---
def predict(teluguish_text):
    """
    Entry point called by the Gradio interface for each submission.

    Delegates to ``transliterate_sentence`` using the module-level model,
    vocabularies and device. A falsy input (empty string or ``None``) is
    answered with an empty string without running the model.
    """
    if teluguish_text:
        # Hand the text to the core transliteration routine.
        return transliterate_sentence(
            MODEL, teluguish_text, SRC_VOCAB, TGT_VOCAB, DEVICE
        )
    return ""
# --- 4. GRADIO INTERFACE: Define the web UI layout ---
# Three-line text box where the user types Roman-script input.
_teluguish_input = gr.Textbox(
    lines=3,
    label="Teluguish Input (Roman Script)",
    placeholder="Type a sentence here... for example, 'namaste andaru bagunnara'",
)
# Text box that displays the value returned by predict().
_telugu_output = gr.Textbox(label="Telugu Script Output")

iface = gr.Interface(
    fn=predict,
    inputs=_teluguish_input,
    outputs=_telugu_output,
    title="Teluguish to Telugu Transliteration with GNN ",
    description="This app uses a Graph Neural Network (GNN) to convert Telugu written in the Roman script (also known as 'Teluguish') into the authentic Telugu script. ",
    allow_flagging="never",
)

# --- 5. LAUNCH THE APP ---
# Start the web server only when this file is executed directly, not on import.
if __name__ == "__main__":
    iface.launch()