# src/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv  # you can switch to GraphConv or GATConv

from .config import SRC_CHAR_EMB, ENC_HIDDEN, TGT_CHAR_EMB, DEC_HIDDEN, GNN_LAYERS, DROPOUT


class GNNEncoder(nn.Module):
    def __init__(self, src_vocab_size, emb_dim=SRC_CHAR_EMB,
                 hidden_dim=ENC_HIDDEN, num_layers=GNN_LAYERS, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(src_vocab_size, emb_dim, padding_idx=pad_idx)
        self.gnn_layers = nn.ModuleList()
        in_dim = emb_dim
        for _ in range(num_layers):
            self.gnn_layers.append(SAGEConv(in_dim, hidden_dim))
            in_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch_vec):
        """
        x: [total_nodes] - char ids
        edge_index: [2, total_edges]
        batch_vec: [total_nodes] - graph index for each node
        Returns:
          encoder_outputs: [B, S, hidden_dim] (zero-padded per word, S = max nodes)
          src_masks: [B, S] bool (True at valid node positions)
        """
        device = x.device
        node_emb = self.embedding(x)  # [total_nodes, emb_dim]
        h = node_emb
        for layer in self.gnn_layers:
            h = layer(h, edge_index)
            h = F.relu(h)
            h = self.dropout(h)

        # group per graph (word): pad node states into a dense [B, S, H] batch
        # (torch_geometric.utils.to_dense_batch implements the same operation)
        batch_size = int(batch_vec.max().item()) + 1
        # max nodes in any graph, computed without a Python loop over the batch
        max_nodes = int(torch.bincount(batch_vec, minlength=batch_size).max().item())

        hidden_dim = h.size(1)
        encoder_outputs = torch.zeros(batch_size, max_nodes, hidden_dim, device=device)
        src_masks = torch.zeros(batch_size, max_nodes, dtype=torch.bool, device=device)

        for i in range(batch_size):
            idx = (batch_vec == i).nonzero(as_tuple=False).squeeze(-1)
            seq_len = idx.size(0)
            encoder_outputs[i, :seq_len, :] = h[idx]
            src_masks[i, :seq_len] = True

        return encoder_outputs, src_masks
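
# Shape sketch (hypothetical ids): for the two-word batch ["cat", "go"],
# x holds 5 char ids, batch_vec = [0, 0, 0, 1, 1], and the encoder returns
# encoder_outputs of shape [2, 3, ENC_HIDDEN] with
# src_masks = [[True, True, True], [True, True, False]].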


class BahdanauAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim):
        super().__init__()
        self.attn = nn.Linear(enc_dim + dec_dim, dec_dim)
        self.v = nn.Linear(dec_dim, 1, bias=False)
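        # additive (Bahdanau) scoring: score(h_t, enc_s) = v^T tanh(W [h_t; enc_s])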

    def forward(self, hidden, encoder_outputs, mask):
        """
        hidden: [B, dec_dim]
        encoder_outputs: [B, S, enc_dim]
        mask: [B, S]
        """
        B, S, _ = encoder_outputs.size()
        hidden_exp = hidden.unsqueeze(1).expand(-1, S, -1)  # [B,S,dec_dim], view without copy
        energy = torch.tanh(self.attn(torch.cat((hidden_exp, encoder_outputs), dim=2)))  # [B,S,dec_dim]
        scores = self.v(energy).squeeze(-1)  # [B,S]

        scores = scores.masked_fill(~mask, -1e9)
        attn_weights = torch.softmax(scores, dim=1)  # [B,S]

        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # [B,1,enc_dim]
        context = context.squeeze(1)  # [B,enc_dim]
        return context, attn_weights


class Decoder(nn.Module):
    def __init__(self, tgt_vocab_size, emb_dim=TGT_CHAR_EMB,
                 enc_dim=ENC_HIDDEN, dec_dim=DEC_HIDDEN, dropout=DROPOUT, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(tgt_vocab_size, emb_dim, padding_idx=pad_idx)
        self.rnn = nn.LSTM(emb_dim + enc_dim, dec_dim, batch_first=True)
        self.attn = BahdanauAttention(enc_dim, dec_dim)
        self.fc_out = nn.Linear(dec_dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, cell, encoder_outputs, mask):
        """
        input_token: [B]
        hidden, cell: [1,B,dec_dim]
        encoder_outputs: [B,S,enc_dim]
        mask: [B,S]
        """
        emb = self.dropout(self.embedding(input_token)).unsqueeze(1)  # [B,1,E]
        dec_hidden = hidden[-1]  # [B,dec_dim]
        context, attn_weights = self.attn(dec_hidden, encoder_outputs, mask)  # [B,enc_dim]
        context = context.unsqueeze(1)  # [B,1,enc_dim]
        rnn_input = torch.cat((emb, context), dim=2)  # [B,1,E+enc_dim]
        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        logits = self.fc_out(output.squeeze(1))  # [B, vocab]
        return logits, hidden, cell, attn_weights
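
# Inference note (a sketch, not used elsewhere in this file): at decode time,
# call the decoder one step at a time starting from the <sos> id, feed
# logits.argmax(-1) back in as the next input_token, and stop on <eos> or at a
# maximum length.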


class Seq2Seq(nn.Module):
    def __init__(self, encoder: GNNEncoder, decoder: Decoder,
                 tgt_sos_idx: int, tgt_eos_idx: int, tgt_pad_idx: int):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_sos_idx = tgt_sos_idx
        self.tgt_eos_idx = tgt_eos_idx
        self.tgt_pad_idx = tgt_pad_idx

    def forward(self, x, edge_index, batch_vec, tgt_padded, teacher_forcing_ratio=0.5):
        """
        x: [total_nodes]
        edge_index: [2, total_edges]
        batch_vec: [total_nodes]
        tgt_padded: [B, T]
        """
        device = x.device
        B, T = tgt_padded.size()
        encoder_outputs, src_mask = self.encoder(x, edge_index, batch_vec)  # [B,S,H], [B,S]

        # init decoder hidden, cell as zeros
        hidden = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)
        cell = torch.zeros(1, B, self.decoder.rnn.hidden_size, device=device)

        # logits for every timestep; index 0 (the <sos> position) stays zero,
        # so the loss should be computed from t = 1 onward
        outputs = torch.zeros(B, T, self.decoder.fc_out.out_features, device=device)

        input_token = tgt_padded[:, 0]  # <sos>
        for t in range(1, T):
            logits, hidden, cell, _ = self.decoder(input_token, hidden, cell, encoder_outputs, src_mask)
            outputs[:, t, :] = logits
            # one coin flip per timestep, shared across the batch
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = logits.argmax(dim=-1)
            input_token = tgt_padded[:, t] if teacher_force else top1

        return outputs  # [B,T,V]
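

# ---------------------------------------------------------------------------
# Smoke test: a minimal sketch with toy values, not part of the training
# pipeline. The vocab sizes, char/token ids, and chain-graph edges below are
# illustrative assumptions. Run as `python -m src.model` so the relative
# import of .config resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # two word-graphs as character chains: word 0 has 3 nodes, word 1 has 2
    x = torch.tensor([3, 4, 5, 6, 7])          # char ids, [total_nodes]
    edge_index = torch.tensor([[0, 1, 3],      # chain edges within each word
                               [1, 2, 4]])
    batch_vec = torch.tensor([0, 0, 0, 1, 1])  # graph index per node

    # targets with <sos>=1, <eos>=2, pad=0; shape [B, T]
    tgt_padded = torch.tensor([[1, 8, 9, 2],
                               [1, 8, 2, 0]])

    model = Seq2Seq(GNNEncoder(src_vocab_size=30),
                    Decoder(tgt_vocab_size=30),
                    tgt_sos_idx=1, tgt_eos_idx=2, tgt_pad_idx=0)
    out = model(x, edge_index, batch_vec, tgt_padded)
    print(out.shape)  # expected: torch.Size([2, 4, 30])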