# src/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv  # you can switch to GraphConv or GATConv

from .config import SRC_CHAR_EMB, ENC_HIDDEN, TGT_CHAR_EMB, DEC_HIDDEN, GNN_LAYERS, DROPOUT


class GNNEncoder(nn.Module):
    def __init__(self, src_vocab_size, emb_dim=SRC_CHAR_EMB,
                 hidden_dim=ENC_HIDDEN, num_layers=GNN_LAYERS, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(src_vocab_size, emb_dim, padding_idx=pad_idx)
        self.gnn_layers = nn.ModuleList()
        in_dim = emb_dim
        for _ in range(num_layers):
            self.gnn_layers.append(SAGEConv(in_dim, hidden_dim))
            in_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch_vec):
        """
        x: [total_nodes] - char ids
        edge_index: [2, total_edges]
        batch_vec: [total_nodes] - graph index for each node
        Returns:
            encoder_outputs: [B, S, hidden_dim] (padded per word)
            src_masks      : [B, S] (True where valid nodes)
        """
        device = x.device
        node_emb = self.embedding(x)  # [total_nodes, emb_dim]
        h = node_emb
        for layer in self.gnn_layers:
            h = layer(h, edge_index)
            h = F.relu(h)
            h = self.dropout(h)

        # Group node states per graph (one graph per word) and pad to a common
        # length. torch_geometric.utils.to_dense_batch does the same padding in
        # one call; the explicit loops keep the logic easy to follow.
        batch_size = batch_vec.max().item() + 1
        max_nodes = 0
        for i in range(batch_size):
            max_nodes = max(max_nodes, (batch_vec == i).sum().item())

        hidden_dim = h.size(1)
        encoder_outputs = torch.zeros(batch_size, max_nodes, hidden_dim, device=device)
        src_masks = torch.zeros(batch_size, max_nodes, dtype=torch.bool, device=device)
        for i in range(batch_size):
            idx = (batch_vec == i).nonzero(as_tuple=False).squeeze(-1)
            seq_len = idx.size(0)
            encoder_outputs[i, :seq_len, :] = h[idx]
            src_masks[i, :seq_len] = True
        return encoder_outputs, src_masks
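

# Illustrative sketch only (an assumption, not part of the original pipeline):
# the (x, edge_index, batch_vec) tensors documented in GNNEncoder.forward are
# expected to come from the data loading code, which lives outside this file.
# The helper below shows one plausible way to assemble them, treating each word
# as a bidirectional chain over its characters; the real graph construction may
# use different edges or node features.
def build_chain_graph_batch(char_id_seqs):
    xs, edges, batch = [], [], []
    offset = 0
    for graph_idx, seq in enumerate(char_id_seqs):
        n = len(seq)
        xs.extend(seq)
        batch.extend([graph_idx] * n)
        for i in range(n - 1):
            # undirected chain: node i <-> node i+1, shifted by the running offset
            edges.append((offset + i, offset + i + 1))
            edges.append((offset + i + 1, offset + i))
        offset += n
    x = torch.tensor(xs, dtype=torch.long)
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.empty(2, 0, dtype=torch.long)
    batch_vec = torch.tensor(batch, dtype=torch.long)
    return x, edge_index, batch_vec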


class BahdanauAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim):
        super().__init__()
        self.attn = nn.Linear(enc_dim + dec_dim, dec_dim)
        self.v = nn.Linear(dec_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):
        """
        hidden: [B, dec_dim]
        encoder_outputs: [B, S, enc_dim]
        mask: [B, S]
        """
        B, S, _ = encoder_outputs.size()
        hidden_exp = hidden.unsqueeze(1).repeat(1, S, 1)  # [B,S,dec_dim]
        energy = torch.tanh(self.attn(torch.cat((hidden_exp, encoder_outputs), dim=2)))  # [B,S,dec_dim]
        scores = self.v(energy).squeeze(-1)  # [B,S]
        scores = scores.masked_fill(~mask, -1e9)
        attn_weights = torch.softmax(scores, dim=1)  # [B,S]
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # [B,1,enc_dim]
        context = context.squeeze(1)  # [B,enc_dim]
        return context, attn_weights


class Decoder(nn.Module):
    def __init__(self, tgt_vocab_size, emb_dim=TGT_CHAR_EMB,
                 enc_dim=ENC_HIDDEN, dec_dim=DEC_HIDDEN, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(tgt_vocab_size, emb_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(emb_dim + enc_dim, dec_dim, batch_first=True)
        self.attn = BahdanauAttention(enc_dim, dec_dim)
        self.fc_out = nn.Linear(dec_dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, cell, encoder_outputs, mask):
        """
        input_token: [B]
        hidden, cell: [1,B,dec_dim]
        encoder_outputs: [B,S,enc_dim]
        mask: [B,S]
        """
        emb = self.dropout(self.embedding(input_token)).unsqueeze(1)  # [B,1,E]
        dec_hidden = hidden[-1]  # [B,dec_dim]
        context, attn_weights = self.attn(dec_hidden, encoder_outputs, mask)  # [B,enc_dim]
        context = context.unsqueeze(1)  # [B,1,enc_dim]
        rnn_input = torch.cat((emb, context), dim=2)  # [B,1,E+enc_dim]
        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        logits = self.fc_out(output.squeeze(1))  # [B, vocab]
        return logits, hidden, cell, attn_weights
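

# Illustrative sketch only (an assumption, not part of the original module):
# a greedy inference loop showing how GNNEncoder and Decoder fit together at
# test time, mirroring the zero-initialised LSTM state used in Seq2Seq.forward
# below. `sos_idx`, `eos_idx` and `max_len` are assumed to match the target
# vocabulary used during training; the caller should put the modules in eval()
# mode first.
@torch.no_grad()
def greedy_decode(encoder, decoder, x, edge_index, batch_vec,
                  sos_idx, eos_idx, max_len=30):
    device = x.device
    encoder_outputs, src_mask = encoder(x, edge_index, batch_vec)
    B = encoder_outputs.size(0)
    hidden = torch.zeros(1, B, decoder.rnn.hidden_size, device=device)
    cell = torch.zeros(1, B, decoder.rnn.hidden_size, device=device)
    input_token = torch.full((B,), sos_idx, dtype=torch.long, device=device)
    finished = torch.zeros(B, dtype=torch.bool, device=device)
    decoded = []
    for _ in range(max_len):
        logits, hidden, cell, _ = decoder(input_token, hidden, cell, encoder_outputs, src_mask)
        input_token = logits.argmax(dim=-1)  # greedy pick per example
        decoded.append(input_token)
        finished |= input_token == eos_idx
        if finished.all():  # every sequence has emitted <eos>
            break
    return torch.stack(decoded, dim=1)  # [B, <=max_len] predicted char ids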


class Seq2Seq(nn.Module):
    def __init__(self, encoder: GNNEncoder, decoder: Decoder,
                 tgt_sos_idx: int, tgt_eos_idx: int, tgt_pad_idx: int):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_sos_idx = tgt_sos_idx
        self.tgt_eos_idx = tgt_eos_idx
        self.tgt_pad_idx = tgt_pad_idx

    def forward(self, x, edge_index, batch_vec, tgt_padded, teacher_forcing_ratio=0.5):
        """
        x: [total_nodes]
        edge_index: [2, total_edges]
        batch_vec: [total_nodes]
        tgt_padded: [B, T]
        """
        device = x.device
        B, T = tgt_padded.size()
        encoder_outputs, src_mask = self.encoder(x, edge_index, batch_vec)  # [B,S,H], [B,S]

        # init decoder hidden, cell as zeros
        hidden = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)
        cell = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)

        # outputs[:, 0, :] stays zero: it corresponds to the <sos> position and
        # should be skipped when computing the loss.
        outputs = torch.zeros(B, T, self.decoder.fc_out.out_features, device=device)
        input_token = tgt_padded[:, 0]  # <sos>
        for t in range(1, T):
            logits, hidden, cell, _ = self.decoder(input_token, hidden, cell, encoder_outputs, src_mask)
            outputs[:, t, :] = logits
            # Teacher forcing: with probability teacher_forcing_ratio feed the
            # gold token, otherwise feed the model's own prediction.
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = logits.argmax(dim=-1)
            input_token = tgt_padded[:, t] if teacher_force else top1
        return outputs  # [B,T,V]
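

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): the toy vocabulary sizes and
    # dimensions below are assumptions, not the values from .config. Because of
    # the relative import above, run this as a module, e.g. `python -m src.model`.
    torch.manual_seed(0)
    enc = GNNEncoder(src_vocab_size=30, emb_dim=16, hidden_dim=32, num_layers=2, dropout=0.1)
    dec = Decoder(tgt_vocab_size=28, emb_dim=16, enc_dim=32, dec_dim=64, dropout=0.1)
    model = Seq2Seq(enc, dec, tgt_sos_idx=1, tgt_eos_idx=2, tgt_pad_idx=0)

    # two toy "words" of 3 and 4 characters, built with the chain-graph helper above
    x, edge_index, batch_vec = build_chain_graph_batch([[5, 6, 7], [8, 9, 10, 11]])
    tgt = torch.tensor([[1, 4, 5, 2, 0],
                        [1, 6, 7, 8, 2]])  # [B, T] with <sos>=1, <eos>=2, <pad>=0
    out = model(x, edge_index, batch_vec, tgt, teacher_forcing_ratio=0.5)
    print(out.shape)  # expected: torch.Size([2, 5, 28])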