import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

from .config import SRC_CHAR_EMB, ENC_HIDDEN, TGT_CHAR_EMB, DEC_HIDDEN, GNN_LAYERS, DROPOUT


class GNNEncoder(nn.Module):
    """Character-graph encoder: embeds node (character) ids and runs a GraphSAGE stack."""

    def __init__(self, src_vocab_size, emb_dim=SRC_CHAR_EMB,
                 hidden_dim=ENC_HIDDEN, num_layers=GNN_LAYERS, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(src_vocab_size, emb_dim, padding_idx=pad_idx)
        self.gnn_layers = nn.ModuleList()
        in_dim = emb_dim
        for _ in range(num_layers):
            self.gnn_layers.append(SAGEConv(in_dim, hidden_dim))
            in_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch_vec):
        """
        x: [total_nodes] - char ids
        edge_index: [2, total_edges]
        batch_vec: [total_nodes] - graph index for each node
        Returns:
            encoder_outputs: [B, S, hidden_dim] (padded per word)
            src_masks: [B, S] (True where valid nodes)
        """
        device = x.device

        # Embed characters, then propagate through the GraphSAGE layers.
        h = self.embedding(x)
        for layer in self.gnn_layers:
            h = layer(h, edge_index)
            h = F.relu(h)
            h = self.dropout(h)

        # Re-pack the flat node representations into a padded [B, S, hidden_dim]
        # tensor plus a boolean validity mask, one row per graph in the batch.
        batch_size = batch_vec.max().item() + 1
        max_nodes = 0
        for i in range(batch_size):
            max_nodes = max(max_nodes, (batch_vec == i).sum().item())

        hidden_dim = h.size(1)
        encoder_outputs = torch.zeros(batch_size, max_nodes, hidden_dim, device=device)
        src_masks = torch.zeros(batch_size, max_nodes, dtype=torch.bool, device=device)

        for i in range(batch_size):
            idx = (batch_vec == i).nonzero(as_tuple=False).squeeze(-1)
            seq_len = idx.size(0)
            encoder_outputs[i, :seq_len, :] = h[idx]
            src_masks[i, :seq_len] = True

        return encoder_outputs, src_masks

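# Note: for the standard PyG batching where batch_vec is non-decreasing, the
# Python-level padding loop in GNNEncoder.forward is equivalent to
# torch_geometric.utils.to_dense_batch, which builds the same dense tensor and
# boolean mask in one vectorised call (a sketch only, not wired into the class):
#
#     from torch_geometric.utils import to_dense_batch
#     encoder_outputs, src_masks = to_dense_batch(h, batch_vec)
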
class BahdanauAttention(nn.Module):
    """Additive attention over encoder outputs, conditioned on the decoder state."""

    def __init__(self, enc_dim, dec_dim):
        super().__init__()
        self.attn = nn.Linear(enc_dim + dec_dim, dec_dim)
        self.v = nn.Linear(dec_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):
        """
        hidden: [B, dec_dim]
        encoder_outputs: [B, S, enc_dim]
        mask: [B, S]
        """
        B, S, _ = encoder_outputs.size()

        # Score every source position against the current decoder state.
        hidden_exp = hidden.unsqueeze(1).repeat(1, S, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden_exp, encoder_outputs), dim=2)))
        scores = self.v(energy).squeeze(-1)

        # Mask out padding positions before the softmax.
        scores = scores.masked_fill(~mask, -1e9)
        attn_weights = torch.softmax(scores, dim=1)

        # Attention-weighted sum of encoder outputs -> context vector [B, enc_dim].
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)
        return context, attn_weights

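# For reference, BahdanauAttention above computes the additive attention score
#     score(s, h_i) = v^T tanh(W [s ; h_i])
# for decoder state s and encoder output h_i, applies a masked softmax over the
# source positions, and returns the weighted sum of encoder outputs as context.
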
class Decoder(nn.Module):
    def __init__(self, tgt_vocab_size, emb_dim=TGT_CHAR_EMB,
                 enc_dim=ENC_HIDDEN, dec_dim=DEC_HIDDEN, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(tgt_vocab_size, emb_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(emb_dim + enc_dim, dec_dim, batch_first=True)
        self.attn = BahdanauAttention(enc_dim, dec_dim)
        self.fc_out = nn.Linear(dec_dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, cell, encoder_outputs, mask):
        """
        input_token: [B]
        hidden, cell: [1, B, dec_dim]
        encoder_outputs: [B, S, enc_dim]
        mask: [B, S]
        """
        # Embed the previous token and attend over the encoder outputs,
        # using the top LSTM hidden state as the attention query.
        emb = self.dropout(self.embedding(input_token)).unsqueeze(1)
        dec_hidden = hidden[-1]
        context, attn_weights = self.attn(dec_hidden, encoder_outputs, mask)

        # Feed [embedding ; context] to the LSTM for a single time step.
        context = context.unsqueeze(1)
        rnn_input = torch.cat((emb, context), dim=2)
        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))

        logits = self.fc_out(output.squeeze(1))
        return logits, hidden, cell, attn_weights

class Seq2Seq(nn.Module):
    def __init__(self, encoder: GNNEncoder, decoder: Decoder,
                 tgt_sos_idx: int, tgt_eos_idx: int, tgt_pad_idx: int):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_sos_idx = tgt_sos_idx
        self.tgt_eos_idx = tgt_eos_idx
        self.tgt_pad_idx = tgt_pad_idx

    def forward(self, x, edge_index, batch_vec, tgt_padded, teacher_forcing_ratio=0.5):
        """
        x: [total_nodes]
        edge_index: [2, total_edges]
        batch_vec: [total_nodes]
        tgt_padded: [B, T]
        """
        device = x.device
        B, T = tgt_padded.size()
        encoder_outputs, src_mask = self.encoder(x, edge_index, batch_vec)

        # The decoder LSTM starts from zero hidden and cell states.
        hidden = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)
        cell = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)

        # outputs[:, 0, :] stays zero; position 0 corresponds to the <sos> token.
        outputs = torch.zeros(B, T, self.decoder.fc_out.out_features, device=device)

        # Step through the target sequence, mixing teacher forcing with the
        # model's own greedy predictions at the given ratio.
        input_token = tgt_padded[:, 0]
        for t in range(1, T):
            logits, hidden, cell, _ = self.decoder(input_token, hidden, cell, encoder_outputs, src_mask)
            outputs[:, t, :] = logits
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = logits.argmax(dim=-1)
            input_token = tgt_padded[:, t] if teacher_force else top1

        return outputs
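

# Minimal smoke test (a sketch, not part of the training pipeline). It assumes
# the module is executed inside its package (e.g. `python -m <pkg>.model`) so
# the relative `.config` import resolves; the vocab sizes, special-token ids,
# and the tiny two-word batch below are made-up placeholders.
if __name__ == "__main__":
    SRC_VOCAB, TGT_VOCAB = 40, 40          # placeholder vocab sizes
    SOS, EOS, PAD = 1, 2, 0                # placeholder special-token ids

    encoder = GNNEncoder(SRC_VOCAB)
    decoder = Decoder(TGT_VOCAB)
    model = Seq2Seq(encoder, decoder, tgt_sos_idx=SOS, tgt_eos_idx=EOS, tgt_pad_idx=PAD)

    # Two "word graphs" (3 nodes + 2 nodes), chain-connected, batched PyG-style.
    x = torch.tensor([5, 6, 7, 8, 9])
    edge_index = torch.tensor([[0, 1, 3],
                               [1, 2, 4]])
    batch_vec = torch.tensor([0, 0, 0, 1, 1])

    # Target batch: <sos> ... <eos>, padded to length 6.
    tgt = torch.tensor([[SOS, 4, 5, 6, EOS, PAD],
                        [SOS, 7, 8, EOS, PAD, PAD]])

    logits = model(x, edge_index, batch_vec, tgt)
    print(logits.shape)  # expected: torch.Size([2, 6, 40])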