# src/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv  # you can switch to GraphConv or GATConv

from .config import SRC_CHAR_EMB, ENC_HIDDEN, TGT_CHAR_EMB, DEC_HIDDEN, GNN_LAYERS, DROPOUT


class GNNEncoder(nn.Module):
    def __init__(self, src_vocab_size, emb_dim=SRC_CHAR_EMB, hidden_dim=ENC_HIDDEN,
                 num_layers=GNN_LAYERS, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(src_vocab_size, emb_dim, padding_idx=pad_idx)
        self.gnn_layers = nn.ModuleList()
        in_dim = emb_dim
        for _ in range(num_layers):
            self.gnn_layers.append(SAGEConv(in_dim, hidden_dim))
            in_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch_vec):
        """
        x: [total_nodes] - char ids
        edge_index: [2, total_edges]
        batch_vec: [total_nodes] - graph index for each node
        Returns:
            encoder_outputs: [B, S, hidden_dim] (padded per word)
            src_masks:       [B, S] (True where valid nodes)
        """
        device = x.device
        h = self.embedding(x)  # [total_nodes, emb_dim]
        for layer in self.gnn_layers:
            h = layer(h, edge_index)
            h = F.relu(h)
            h = self.dropout(h)

        # group node features per graph (word), padded to the largest graph
        batch_size = int(batch_vec.max().item()) + 1
        max_nodes = int(torch.bincount(batch_vec, minlength=batch_size).max().item())

        hidden_dim = h.size(1)
        encoder_outputs = torch.zeros(batch_size, max_nodes, hidden_dim, device=device)
        src_masks = torch.zeros(batch_size, max_nodes, dtype=torch.bool, device=device)
        for i in range(batch_size):
            idx = (batch_vec == i).nonzero(as_tuple=False).squeeze(-1)
            seq_len = idx.size(0)
            encoder_outputs[i, :seq_len, :] = h[idx]
            src_masks[i, :seq_len] = True
        return encoder_outputs, src_masks


class BahdanauAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim):
        super().__init__()
        self.attn = nn.Linear(enc_dim + dec_dim, dec_dim)
        self.v = nn.Linear(dec_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):
        """
        hidden: [B, dec_dim]
        encoder_outputs: [B, S, enc_dim]
        mask: [B, S]
        """
        B, S, _ = encoder_outputs.size()
        hidden_exp = hidden.unsqueeze(1).expand(-1, S, -1)  # [B, S, dec_dim] (view, no copy)
        energy = torch.tanh(self.attn(torch.cat((hidden_exp, encoder_outputs), dim=2)))  # [B, S, dec_dim]
        scores = self.v(energy).squeeze(-1)  # [B, S]
        scores = scores.masked_fill(~mask, -1e9)
        attn_weights = torch.softmax(scores, dim=1)  # [B, S]
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # [B, enc_dim]
        return context, attn_weights


class Decoder(nn.Module):
    def __init__(self, tgt_vocab_size, emb_dim=TGT_CHAR_EMB, enc_dim=ENC_HIDDEN,
                 dec_dim=DEC_HIDDEN, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(tgt_vocab_size, emb_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(emb_dim + enc_dim, dec_dim, batch_first=True)
        self.attn = BahdanauAttention(enc_dim, dec_dim)
        self.fc_out = nn.Linear(dec_dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, cell, encoder_outputs, mask):
        """
        input_token: [B]
        hidden, cell: [1, B, dec_dim]
        encoder_outputs: [B, S, enc_dim]
        mask: [B, S]
        """
        emb = self.dropout(self.embedding(input_token)).unsqueeze(1)  # [B, 1, E]
        dec_hidden = hidden[-1]  # [B, dec_dim]
        context, attn_weights = self.attn(dec_hidden, encoder_outputs, mask)  # [B, enc_dim]
        rnn_input = torch.cat((emb, context.unsqueeze(1)), dim=2)  # [B, 1, E + enc_dim]
        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        logits = self.fc_out(output.squeeze(1))  # [B, tgt_vocab_size]
        return logits, hidden, cell, attn_weights
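
# Illustrative single decoder step (a sketch, shapes only; B=2, S=4 and the
# vocab size 40 below are made-up numbers, DEC_HIDDEN/ENC_HIDDEN come from .config):
#
#   dec = Decoder(tgt_vocab_size=40)
#   logits, h, c, w = dec(
#       input_token=torch.tensor([1, 1]),               # [B]
#       hidden=torch.zeros(1, 2, DEC_HIDDEN),           # [1, B, dec_dim]
#       cell=torch.zeros(1, 2, DEC_HIDDEN),
#       encoder_outputs=torch.zeros(2, 4, ENC_HIDDEN),  # [B, S, enc_dim]
#       mask=torch.ones(2, 4, dtype=torch.bool),        # [B, S]
#   )
#   # logits: [2, 40], w (attention weights): [2, 4]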


class Seq2Seq(nn.Module):
    def __init__(self, encoder: GNNEncoder, decoder: Decoder,
                 tgt_sos_idx: int, tgt_eos_idx: int, tgt_pad_idx: int):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_sos_idx = tgt_sos_idx
        self.tgt_eos_idx = tgt_eos_idx
        self.tgt_pad_idx = tgt_pad_idx

    def forward(self, x, edge_index, batch_vec, tgt_padded, teacher_forcing_ratio=0.5):
        """
        x: [total_nodes]
        edge_index: [2, total_edges]
        batch_vec: [total_nodes]
        tgt_padded: [B, T]
        """
        device = x.device
        B, T = tgt_padded.size()
        encoder_outputs, src_mask = self.encoder(x, edge_index, batch_vec)  # [B,S,H], [B,S]

        # init decoder hidden/cell as zeros (single-layer LSTM)
        hidden = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)
        cell = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)

        outputs = torch.zeros(B, T, self.decoder.fc_out.out_features, device=device)
        input_token = tgt_padded[:, 0]  # first target position is <sos>
        for t in range(1, T):
            logits, hidden, cell, _ = self.decoder(input_token, hidden, cell,
                                                   encoder_outputs, src_mask)
            outputs[:, t, :] = logits
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = logits.argmax(dim=-1)
            input_token = tgt_padded[:, t] if teacher_force else top1
        return outputs  # [B, T, V]
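

# Minimal smoke test, sketched under assumptions: the vocab sizes, special
# indices, and the chain-graph edges below are illustrative, not values from
# the real vocabularies. Run as a module so the relative import resolves:
#   python -m src.model
if __name__ == "__main__":
    enc = GNNEncoder(src_vocab_size=30)
    dec = Decoder(tgt_vocab_size=40)
    model = Seq2Seq(enc, dec, tgt_sos_idx=1, tgt_eos_idx=2, tgt_pad_idx=0)

    # two "words" as chain graphs over character nodes: 3 nodes and 2 nodes
    x = torch.tensor([5, 6, 7, 8, 9])          # char ids, total_nodes = 5
    edge_index = torch.tensor([[0, 1, 3],
                               [1, 2, 4]])     # directed edges within each word
    batch_vec = torch.tensor([0, 0, 0, 1, 1])  # node -> graph assignment
    tgt = torch.tensor([[1, 10, 11, 2],
                        [1, 12, 2, 0]])        # [B=2, T=4], 0-padded

    out = model(x, edge_index, batch_vec, tgt)
    print(out.shape)                           # torch.Size([2, 4, 40])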