# src/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv  # swap in GraphConv or GATConv if desired
                                         # (note: GATConv with heads > 1 multiplies the output width)

from .config import SRC_CHAR_EMB, ENC_HIDDEN, TGT_CHAR_EMB, DEC_HIDDEN, GNN_LAYERS, DROPOUT

class GNNEncoder(nn.Module):
    def __init__(self, src_vocab_size, emb_dim=SRC_CHAR_EMB,
                 hidden_dim=ENC_HIDDEN, num_layers=GNN_LAYERS, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(src_vocab_size, emb_dim, padding_idx=pad_idx)
        self.gnn_layers = nn.ModuleList()
        in_dim = emb_dim
        for _ in range(num_layers):
            self.gnn_layers.append(SAGEConv(in_dim, hidden_dim))
            in_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch_vec):
        """
        x: [total_nodes] - char ids
        edge_index: [2, total_edges]
        batch_vec: [total_nodes] - graph index for each node
        Returns:
            encoder_outputs: [B, S, hidden_dim] (padded per word)
            src_masks:       [B, S] (True where valid nodes)
        """
        device = x.device
        h = self.embedding(x)  # [total_nodes, emb_dim]
        for layer in self.gnn_layers:
            h = layer(h, edge_index)
            h = F.relu(h)
            h = self.dropout(h)

        # Regroup the flat node tensor into one padded sequence per graph (word).
        # (torch_geometric.utils.to_dense_batch does the same thing, vectorized.)
        batch_size = batch_vec.max().item() + 1
        max_nodes = torch.bincount(batch_vec, minlength=batch_size).max().item()
        hidden_dim = h.size(1)
        encoder_outputs = torch.zeros(batch_size, max_nodes, hidden_dim, device=device)
        src_masks = torch.zeros(batch_size, max_nodes, dtype=torch.bool, device=device)
        for i in range(batch_size):
            idx = (batch_vec == i).nonzero(as_tuple=False).squeeze(-1)
            seq_len = idx.size(0)
            encoder_outputs[i, :seq_len, :] = h[idx]
            src_masks[i, :seq_len] = True
        return encoder_outputs, src_masks
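
# Usage sketch (illustrative, not run on import): the encoder consumes the
# flattened tensors that torch_geometric.data.Batch produces. Names like
# `another_word_graph` are placeholders, not part of this codebase.
#
#   from torch_geometric.data import Data, Batch
#   g = Data(x=torch.tensor([3, 7, 2]),                  # char ids of one word
#            edge_index=torch.tensor([[0, 1], [1, 2]]))  # chain edges 0->1->2
#   batch = Batch.from_data_list([g, another_word_graph])
#   enc_out, mask = encoder(batch.x, batch.edge_index, batch.batch)
#   # enc_out: [2, max_nodes, ENC_HIDDEN], mask: [2, max_nodes]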

class BahdanauAttention(nn.Module):
    """Additive attention: score(h_t, e_s) = v^T tanh(W [h_t ; e_s])."""

    def __init__(self, enc_dim, dec_dim):
        super().__init__()
        self.attn = nn.Linear(enc_dim + dec_dim, dec_dim)
        self.v = nn.Linear(dec_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):
        """
        hidden: [B, dec_dim]
        encoder_outputs: [B, S, enc_dim]
        mask: [B, S] - True where valid source positions
        """
        B, S, _ = encoder_outputs.size()
        hidden_exp = hidden.unsqueeze(1).expand(-1, S, -1)  # [B, S, dec_dim]
        energy = torch.tanh(self.attn(torch.cat((hidden_exp, encoder_outputs), dim=2)))  # [B, S, dec_dim]
        scores = self.v(energy).squeeze(-1)  # [B, S]
        # Mask padded positions before the softmax so they get ~zero weight.
        scores = scores.masked_fill(~mask, -1e9)
        attn_weights = torch.softmax(scores, dim=1)  # [B, S]
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # [B, enc_dim]
        return context, attn_weights
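
# Shape sketch (illustrative): for a batch of 2 words with at most 4 chars,
#   attn = BahdanauAttention(enc_dim=ENC_HIDDEN, dec_dim=DEC_HIDDEN)
#   context, w = attn(dec_hidden, enc_out, mask)
#   # dec_hidden: [2, DEC_HIDDEN] -> context: [2, ENC_HIDDEN], w: [2, 4]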

class Decoder(nn.Module):
    def __init__(self, tgt_vocab_size, emb_dim=TGT_CHAR_EMB,
                 enc_dim=ENC_HIDDEN, dec_dim=DEC_HIDDEN, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(tgt_vocab_size, emb_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(emb_dim + enc_dim, dec_dim, batch_first=True)
        self.attn = BahdanauAttention(enc_dim, dec_dim)
        self.fc_out = nn.Linear(dec_dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, cell, encoder_outputs, mask):
        """
        One decoding step.
        input_token: [B]
        hidden, cell: [1, B, dec_dim]
        encoder_outputs: [B, S, enc_dim]
        mask: [B, S]
        """
        emb = self.dropout(self.embedding(input_token)).unsqueeze(1)  # [B, 1, E]
        # Attend over the encoder states using the previous decoder hidden state.
        dec_hidden = hidden[-1]  # [B, dec_dim]
        context, attn_weights = self.attn(dec_hidden, encoder_outputs, mask)  # [B, enc_dim]
        # Feed the context into the LSTM alongside the embedded token.
        rnn_input = torch.cat((emb, context.unsqueeze(1)), dim=2)  # [B, 1, E + enc_dim]
        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        logits = self.fc_out(output.squeeze(1))  # [B, vocab]
        return logits, hidden, cell, attn_weights

class Seq2Seq(nn.Module):
    def __init__(self, encoder: GNNEncoder, decoder: Decoder,
                 tgt_sos_idx: int, tgt_eos_idx: int, tgt_pad_idx: int):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_sos_idx = tgt_sos_idx
        self.tgt_eos_idx = tgt_eos_idx
        self.tgt_pad_idx = tgt_pad_idx

    def forward(self, x, edge_index, batch_vec, tgt_padded, teacher_forcing_ratio=0.5):
        """
        x: [total_nodes]
        edge_index: [2, total_edges]
        batch_vec: [total_nodes]
        tgt_padded: [B, T] - starts with <sos>
        Returns logits [B, T, V]; position 0 is never written (it stays zeros)
        and should be ignored by the loss.
        """
        device = x.device
        B, T = tgt_padded.size()
        encoder_outputs, src_mask = self.encoder(x, edge_index, batch_vec)  # [B, S, H], [B, S]

        # Initialize the decoder LSTM state with zeros.
        hidden = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)
        cell = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)
        outputs = torch.zeros(B, T, self.decoder.fc_out.out_features, device=device)

        input_token = tgt_padded[:, 0]  # <sos>
        for t in range(1, T):
            logits, hidden, cell, _ = self.decoder(input_token, hidden, cell, encoder_outputs, src_mask)
            outputs[:, t, :] = logits
            # Teacher forcing: feed the gold token with the given probability,
            # otherwise feed the model's own prediction.
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = logits.argmax(dim=-1)
            input_token = tgt_padded[:, t] if teacher_force else top1
        return outputs  # [B, T, V]
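
if __name__ == "__main__":
    # Minimal smoke test: a sketch with made-up vocab sizes and a toy batch of
    # two "word" graphs (a 3-node chain and a 2-node chain), flattened the way
    # torch_geometric's Batch flattens them. Run as `python -m src.model` so
    # the relative config import resolves.
    enc = GNNEncoder(src_vocab_size=30)
    dec = Decoder(tgt_vocab_size=40)
    model = Seq2Seq(enc, dec, tgt_sos_idx=1, tgt_eos_idx=2, tgt_pad_idx=0)

    x = torch.tensor([3, 7, 2, 5, 9])                  # char ids, both graphs
    edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])  # chains 0->1->2 and 3->4
    batch_vec = torch.tensor([0, 0, 0, 1, 1])          # node -> graph index
    tgt = torch.tensor([[1, 4, 5, 2],                  # <sos> ... <eos>
                        [1, 6, 2, 0]])                 # shorter word + pad
    out = model(x, edge_index, batch_vec, tgt)
    print(out.shape)  # expected: torch.Size([2, 4, 40])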