import re

import torch
from torch_geometric.data import Data

import setup
from dataset import TeluguishDataset, Vocabulary
from model import create_model


def create_graph_from_text(text, vocab):
    """
    Converts a single Teluguish string into a PyG Data object for inference.
    """
    tokens = list(text)

    # Convert tokens to indices using the loaded vocabulary
    # (unknown characters fall back to the padding index)
    node_features = torch.tensor(
        [vocab.stoi.get(token, setup.PAD_TOKEN) for token in tokens],
        dtype=torch.long,
    )

    # Create sequential edges between adjacent characters
    edge_list = []
    if len(tokens) > 1:
        for i in range(len(tokens) - 1):
            edge_list.extend([[i, i + 1], [i + 1, i]])  # Bidirectional edges

    # Guard against single-character inputs: an empty edge list would otherwise
    # yield a 1-D tensor instead of the (2, num_edges) shape PyG expects.
    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # Create a Data object
    return Data(x=node_features, edge_index=edge_index)


def transliterate_sentence(model, sentence, src_vocab, tgt_vocab, device, max_len=100):
    """
    Performs the transliteration for a single sentence using greedy decoding.
    """
    model.eval()  # Set the model to evaluation mode

    # Preprocess the input sentence
    sentence = sentence.lower().strip()
    sentence = re.sub(r'[^a-z\s]', '', sentence)
    if not sentence:
        return ""

    # Create a graph from the input sentence
    src_graph = create_graph_from_text(sentence, src_vocab)
    src_graph = src_graph.to(device)

    # The batch vector for a single graph is just a tensor of zeros
    src_batch_vector = torch.zeros(src_graph.x.shape[0], dtype=torch.long).to(device)

    with torch.no_grad():
        # Pass the graph through the encoder to get the context
        _, graph_embedding = model.encoder(src_graph.x, src_graph.edge_index, src_batch_vector)

    # Use the graph embedding as the initial hidden state for the decoder;
    # the cell state starts at zero
    hidden = graph_embedding.unsqueeze(0)
    cell = torch.zeros(1, 1, setup.HIDDEN_DIM).to(device)

    # Start the output sequence with the <sos> token
    trg_indexes = [setup.SOS_TOKEN]

    for _ in range(max_len):
        # Get the last predicted token
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)

        with torch.no_grad():
            # Pass the previous token, hidden states, and context to the decoder
            output, hidden, cell = model.decoder(trg_tensor, hidden, cell, graph_embedding)

        # Greedily take the single most likely token
        pred_token = output.argmax(1).item()
        trg_indexes.append(pred_token)

        # Stop if the model predicts the <eos> token
        if pred_token == setup.EOS_TOKEN:
            break

    # Convert the predicted indices back to characters using the target vocabulary
    trg_tokens = [tgt_vocab.itos[i] for i in trg_indexes]

    # Return the final string, excluding the <sos> and <eos> tokens
    return "".join(trg_tokens[1:-1])


# --- Main execution block ---
if __name__ == '__main__':
    # Set the device (GPU or CPU)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- Load Vocabularies ---
    # We load the vocabularies from the dataset's processed file
    print("Loading vocabularies...")
    dataset = TeluguishDataset(root='.')
    src_vocab = dataset.src_vocab
    tgt_vocab = dataset.tgt_vocab
    print("Vocabularies loaded.")

    # --- Load Model ---
    print("Loading trained model...")
    # First, create the model structure with the correct vocabulary sizes
    model = create_model(len(src_vocab), len(tgt_vocab))
    # Then, load the saved weights
    model.load_state_dict(torch.load(setup.MODEL_PATH, map_location=device))
    model.to(device)
    print("Model loaded.")

    # --- Example Usage ---
    print("\n--- Transliteration Examples ---")

    input_sentence_1 = "akada mediatho matladina anantaram tirigi Hyderabad bayaluderutaaru"
    output_sentence_1 = transliterate_sentence(model, input_sentence_1, src_vocab, tgt_vocab, device)
    print(f"\nInput: {input_sentence_1}")
    print(f"Output: {output_sentence_1}")

    input_sentence_2 = "ee roju ma intlo bhojanam cheyandi"
    output_sentence_2 = transliterate_sentence(model, input_sentence_2, src_vocab, tgt_vocab, device)
    print(f"\nInput: {input_sentence_2}")
    print(f"Output: {output_sentence_2}")

    input_sentence_3 = "namaste andaru bagunnara"
    output_sentence_3 = transliterate_sentence(model, input_sentence_3, src_vocab, tgt_vocab, device)
    print(f"\nInput: {input_sentence_3}")
    print(f"Output: {output_sentence_3}")