import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool

import setup


class GNNEncoder(nn.Module):
    """Graph-attention encoder producing node-level and graph-level embeddings.

    Node ids are embedded, passed through a stack of ``GATConv`` layers (each
    followed by ReLU and dropout), and finally mean-pooled per graph.

    Args:
        input_dim: Size of the node-id vocabulary for the embedding table.
        embed_dim: Width of the node embeddings.
        hidden_dim: Output width of every ``GATConv`` layer.
        n_layers: Number of ``GATConv`` layers. At least one layer is always
            created, even if a value below 1 is passed.
        dropout: Dropout probability applied after every layer.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        # The first layer maps embed_dim -> hidden_dim; the rest keep hidden_dim.
        self.convs = nn.ModuleList([GATConv(embed_dim, hidden_dim)])
        self.convs.extend(
            [GATConv(hidden_dim, hidden_dim) for _ in range(n_layers - 1)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        """Encode a batch of graphs.

        Args:
            x: Node-id tensor fed to the embedding table.
            edge_index: Graph connectivity passed to every ``GATConv``.
            batch: Node-to-graph assignment used by ``global_mean_pool``.

        Returns:
            Tuple ``(node_embeddings, graph_embedding)`` where the second item
            is the per-graph mean of the first.
        """
        x = self.embedding(x)
        for conv in self.convs:
            x = self.dropout(F.relu(conv(x, edge_index)))
        graph_embedding = global_mean_pool(x, batch)
        return x, graph_embedding


class Decoder(nn.Module):
    """Single-step LSTM decoder conditioned on a fixed context vector.

    At every step the embedded input token is concatenated with the context
    vector before entering the LSTM, so the LSTM input width is
    ``embed_dim + hidden_dim``.

    Args:
        output_dim: Target vocabulary size (embedding rows and logit width).
        embed_dim: Width of the target-token embeddings.
        hidden_dim: LSTM hidden size; also the expected context width.
        dropout: Dropout probability applied to the token embeddings.
    """

    def __init__(self, output_dim, embed_dim, hidden_dim, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.rnn = nn.LSTM(embed_dim + hidden_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, context):
        """Run one decoding step.

        Args:
            input: Token ids for the current step (one per batch element).
                The name shadows the builtin but is kept so that keyword
                callers keep working.
            hidden: LSTM hidden state from the previous step.
            cell: LSTM cell state from the previous step.
            context: Per-example context vector concatenated to the embedding.

        Returns:
            Tuple ``(prediction, hidden, cell)``: vocabulary logits for this
            step plus the updated LSTM state.
        """
        # Add the leading sequence axis (length 1) expected by the LSTM.
        step_tokens = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(step_tokens))
        rnn_input = torch.cat((embedded, context.unsqueeze(0)), dim=2)
        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    """Graph-to-sequence model: a graph encoder feeding a step-wise decoder.

    Args:
        encoder: Module returning ``(node_embeddings, graph_embedding)``.
        decoder: Module exposing ``output_dim`` and a single-step ``forward``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_nodes, src_edge_index, src_batch, trg,
                teacher_forcing_ratio=0.5):
        """Decode the target sequence conditioned on the encoded source graphs.

        Args:
            src_nodes: Node ids of the batched source graphs.
            src_edge_index: Connectivity of the batched source graphs.
            src_batch: Node-to-graph assignment vector.
            trg: Target token ids, indexed as ``trg[batch, time]``.
            teacher_forcing_ratio: Probability, drawn independently per step,
                of feeding the ground-truth token instead of the model's own
                prediction as the next input.

        Returns:
            Tensor of shape ``(trg_len, batch_size, vocab)`` holding the
            logits of every step. Row 0 is left as zeros because decoding
            starts from ``trg[:, 0]``.
        """
        batch_size, trg_len = trg.shape[0], trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        _, graph_embedding = self.encoder(src_nodes, src_edge_index, src_batch)

        # Allocate directly on the device the encoder produced its output on,
        # rather than on a global setting, so the model keeps working after
        # being moved to another device.
        outputs = torch.zeros(
            trg_len, batch_size, trg_vocab_size, device=graph_embedding.device
        )

        # The graph embedding seeds the LSTM hidden state; the cell state
        # starts at zero with exactly the same shape, dtype and device.
        hidden = graph_embedding.unsqueeze(0)
        cell = torch.zeros_like(hidden)

        decoder_input = trg[:, 0]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(
                decoder_input, hidden, cell, graph_embedding
            )
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            decoder_input = trg[:, t] if teacher_force else output.argmax(1)
        return outputs


def create_model(src_vocab_size, tgt_vocab_size):
    """Create the Seq2Seq model, move it to ``setup.DEVICE`` and initialise it.

    Parameters whose qualified name contains ``'weight'`` are drawn from
    ``N(0, 0.01**2)``; every other parameter is set to zero.

    Args:
        src_vocab_size: Number of distinct source node ids.
        tgt_vocab_size: Number of distinct target tokens.

    Returns:
        The initialised ``Seq2Seq`` model.
    """
    encoder = GNNEncoder(
        src_vocab_size,
        setup.EMBEDDING_DIM,
        setup.HIDDEN_DIM,
        setup.GNN_LAYERS,
        setup.DROPOUT,
    )
    decoder = Decoder(
        tgt_vocab_size,
        setup.EMBEDDING_DIM,
        setup.HIDDEN_DIM,
        setup.DROPOUT,
    )
    model = Seq2Seq(encoder, decoder).to(setup.DEVICE)

    # named_parameters() is already recursive, so one pass over the root
    # module reaches every parameter exactly once. Routing this through
    # Module.apply() would re-initialise each parameter once per ancestor.
    # NOTE(review): the zero branch presumably also covers GATConv attention
    # vectors, whose names do not appear to contain 'weight' -- confirm that
    # zero-initialised attention is intended.
    for name, param in model.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
    return model