Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import GATConv, global_mean_pool | |
| import random | |
| import setup | |
class GNNEncoder(nn.Module):
    """Encode token-labelled graphs with a stack of GAT layers.

    Node token ids are embedded, passed through ``n_layers`` GATConv layers
    (at least one), each followed by ReLU and dropout, then mean-pooled per
    graph.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        # The first layer maps embed_dim -> hidden_dim and every later layer is
        # hidden_dim -> hidden_dim; a non-positive n_layers still yields one layer.
        layer_inputs = [embed_dim] + [hidden_dim] * (n_layers - 1)
        self.convs = nn.ModuleList(GATConv(in_size, hidden_dim) for in_size in layer_inputs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        """Return ``(node_embeddings, graph_embeddings)``.

        ``batch`` is the node-to-graph assignment consumed by
        ``global_mean_pool``; the second result holds one pooled vector per
        graph.
        """
        node_repr = self.embedding(x)
        for gat_layer in self.convs:
            node_repr = self.dropout(F.relu(gat_layer(node_repr, edge_index)))
        return node_repr, global_mean_pool(node_repr, batch)
class Decoder(nn.Module):
    """Single-step LSTM decoder conditioned on a fixed context vector.

    Each step embeds the current tokens, concatenates the context vector onto
    the embedding, runs one LSTM step and projects to vocabulary logits.
    """

    def __init__(self, output_dim, embed_dim, hidden_dim, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, embed_dim)
        # LSTM input is the token embedding concatenated with the context vector.
        self.rnn = nn.LSTM(embed_dim + hidden_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, context):
        """Run one decoding step; return ``(logits, hidden, cell)``.

        Assumes ``input`` is a 1-D batch of token ids and ``context`` is
        ``(batch, hidden_dim)``; both gain a leading length-1 time axis for
        the (sequence-first) LSTM.
        """
        step_tokens = input.unsqueeze(0)
        step_embedded = self.dropout(self.embedding(step_tokens))
        lstm_in = torch.cat((step_embedded, context.unsqueeze(0)), dim=2)
        lstm_out, (next_hidden, next_cell) = self.rnn(lstm_in, (hidden, cell))
        logits = self.fc_out(lstm_out.squeeze(0))
        return logits, next_hidden, next_cell
class Seq2Seq(nn.Module):
    """Graph-to-sequence model: a graph encoder seeding a step-wise decoder.

    The encoder's pooled graph embedding serves both as the decoder's initial
    hidden state and as the context vector supplied at every decoding step.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_nodes, src_edge_index, src_batch, trg, teacher_forcing_ratio=0.5):
        """Decode target sequences for a batch of source graphs.

        Args:
            src_nodes: node inputs forwarded to the encoder.
            src_edge_index: graph connectivity forwarded to the encoder.
            src_batch: node-to-graph assignment forwarded to the encoder.
            trg: target token ids laid out as ``(batch, trg_len)``;
                ``trg[:, 0]`` is the first decoder input.
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token ``trg[:, t]`` instead of the model's own
                argmax prediction.

        Returns:
            Tensor ``(trg_len, batch, decoder.output_dim)`` of decoder logits.
            Row 0 stays zero because no prediction is made for the first
            target position.
        """
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        _, graph_embedding = self.encoder(src_nodes, src_edge_index, src_batch)

        # Allocate on the device the model actually runs on, not on a global
        # setting, so the model keeps working after being moved with .to().
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=graph_embedding.device)

        hidden = graph_embedding.unsqueeze(0)
        # The LSTM cell state must match the hidden state's shape exactly, so
        # derive it from `hidden` instead of a global HIDDEN_DIM constant.
        cell = torch.zeros_like(hidden)

        dec_input = trg[:, 0]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(dec_input, hidden, cell, graph_embedding)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            dec_input = trg[:, t] if teacher_force else top1
        return outputs
| def create_model(src_vocab_size, tgt_vocab_size): | |
| """Creates and initializes the Seq2Seq model.""" | |
| encoder = GNNEncoder(src_vocab_size, setup.EMBEDDING_DIM, setup.HIDDEN_DIM, setup.GNN_LAYERS, setup.DROPOUT) | |
| decoder = Decoder(tgt_vocab_size, setup.EMBEDDING_DIM, setup.HIDDEN_DIM, setup.DROPOUT) | |
| model = Seq2Seq(encoder, decoder).to(setup.DEVICE) | |
| # Initialize weights | |
| def init_weights(m): | |
| for name, param in m.named_parameters(): | |
| if 'weight' in name: | |
| nn.init.normal_(param.data, mean=0, std=0.01) | |
| else: | |
| nn.init.constant_(param.data, 0) | |
| model.apply(init_weights) | |
| return model |