# Transliteration / model.py
# Anudeep Tippabathuni
# Fixed-3
# d1dce40
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
import random
import setup
class GNNEncoder(nn.Module):
    """Graph encoder: token embedding followed by a stack of GATConv layers.

    Args:
        input_dim: Number of rows in the embedding table (source vocabulary size).
        embed_dim: Width of each embedding vector.
        hidden_dim: Output width of every GATConv layer.
        n_layers: Number of GATConv layers requested. The first layer is
            always created, so values below 1 still yield one layer.
        dropout: Dropout probability applied after each layer's ReLU.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        # First layer maps embed_dim -> hidden_dim; the remaining
        # (n_layers - 1) layers keep the width at hidden_dim.
        first_layer = GATConv(embed_dim, hidden_dim)
        later_layers = [GATConv(hidden_dim, hidden_dim) for _ in range(n_layers - 1)]
        self.convs = nn.ModuleList([first_layer, *later_layers])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        """Encode a batch of graphs.

        Args:
            x: Node indices, looked up in the embedding table.
            edge_index: Graph connectivity, forwarded to every GATConv layer.
            batch: Node-to-graph assignment, forwarded to ``global_mean_pool``.

        Returns:
            A ``(node_features, graph_embedding)`` tuple: the per-node
            features after the last layer, and their mean-pooled value
            per graph.
        """
        node_features = self.embedding(x)
        for layer in self.convs:
            activated = F.relu(layer(node_features, edge_index))
            node_features = self.dropout(activated)
        return node_features, global_mean_pool(node_features, batch)
class Decoder(nn.Module):
    """Single-step LSTM decoder conditioned on a fixed context vector.

    Args:
        output_dim: Target vocabulary size (embedding rows and logit width).
        embed_dim: Width of each target-token embedding.
        hidden_dim: LSTM hidden size; also the expected width of ``context``.
        dropout: Dropout probability applied to the token embedding.
    """

    def __init__(self, output_dim, embed_dim, hidden_dim, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, embed_dim)
        # The LSTM consumes the token embedding concatenated with the context.
        lstm_input_size = embed_dim + hidden_dim
        self.rnn = nn.LSTM(lstm_input_size, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, context):
        """Run one decoding step.

        Args:
            input: 1-D tensor of token indices, one per batch element.
            hidden: LSTM hidden state passed through to ``self.rnn``.
            cell: LSTM cell state passed through to ``self.rnn``.
            context: 2-D tensor concatenated to the embedding along the
                feature axis.

        Returns:
            ``(prediction, hidden, cell)`` where ``prediction`` holds the
            unnormalised scores over the vocabulary and ``hidden``/``cell``
            are the updated LSTM states.
        """
        # Add a leading length-1 sequence axis to both LSTM input parts.
        step_tokens = input.unsqueeze(0)
        step_context = context.unsqueeze(0)

        token_vectors = self.dropout(self.embedding(step_tokens))
        lstm_in = torch.cat([token_vectors, step_context], dim=-1)

        lstm_out, next_state = self.rnn(lstm_in, (hidden, cell))
        next_hidden, next_cell = next_state

        logits = self.fc_out(lstm_out.squeeze(0))
        return logits, next_hidden, next_cell
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes a target sequence step by step.

    Args:
        encoder: Callable returning ``(node_features, graph_embedding)`` for
            ``(src_nodes, src_edge_index, src_batch)``.
        decoder: Callable with an ``output_dim`` attribute that maps
            ``(input, hidden, cell, context)`` to ``(prediction, hidden, cell)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_nodes, src_edge_index, src_batch, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg`` conditioned on the encoded source graphs.

        Args:
            src_nodes: Forwarded to the encoder unchanged.
            src_edge_index: Forwarded to the encoder unchanged.
            src_batch: Forwarded to the encoder unchanged.
            trg: Target token indices of shape ``(batch, trg_len)``; column 0
                is used as the first decoder input.
            teacher_forcing_ratio: Probability, drawn once per step, of
                feeding the ground-truth token instead of the model's own
                argmax prediction as the next input.

        Returns:
            Tensor of shape ``(trg_len, batch, decoder.output_dim)``. Row 0 is
            never written and stays all zeros.
        """
        batch_size, trg_len = trg.shape[0], trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        _, graph_embedding = self.encoder(src_nodes, src_edge_index, src_batch)

        # Size and place the buffers from the encoder output rather than from
        # module-level configuration, so the wrapper works for any hidden
        # width / device the wrapped modules were actually built with.
        outputs = torch.zeros(
            trg_len, batch_size, trg_vocab_size, device=graph_embedding.device
        )
        hidden = graph_embedding.unsqueeze(0)
        cell = torch.zeros_like(hidden)

        decoder_input = trg[:, 0]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(
                decoder_input, hidden, cell, graph_embedding
            )
            outputs[t] = output
            # The draw and the argmax happen on every step, as before.
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            decoder_input = trg[:, t] if teacher_force else top1
        return outputs
def create_model(src_vocab_size, tgt_vocab_size):
    """Create and initialise the Seq2Seq model.

    The encoder and decoder are sized from ``setup.EMBEDDING_DIM``,
    ``setup.HIDDEN_DIM``, ``setup.GNN_LAYERS`` and ``setup.DROPOUT``, and the
    assembled model is moved to ``setup.DEVICE`` before initialisation.

    Args:
        src_vocab_size: Size of the source vocabulary (encoder embedding rows).
        tgt_vocab_size: Size of the target vocabulary (decoder embedding rows
            and output width).

    Returns:
        The initialised ``Seq2Seq`` model. Parameters whose name contains
        ``"weight"`` are drawn from N(0, 0.01); every other parameter (biases
        and anything else without "weight" in its name) is set to zero.
    """
    encoder = GNNEncoder(
        src_vocab_size,
        setup.EMBEDDING_DIM,
        setup.HIDDEN_DIM,
        setup.GNN_LAYERS,
        setup.DROPOUT,
    )
    decoder = Decoder(
        tgt_vocab_size,
        setup.EMBEDDING_DIM,
        setup.HIDDEN_DIM,
        setup.DROPOUT,
    )
    model = Seq2Seq(encoder, decoder).to(setup.DEVICE)

    # named_parameters() already recurses through every submodule, so a single
    # pass over the root model touches each parameter exactly once. Wrapping
    # the same loop in model.apply() would re-initialise every parameter once
    # per ancestor module for no benefit.
    for name, param in model.named_parameters():
        if "weight" in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
    return model