import torch
from torch.utils.data import Dataset

from .config import MAX_SRC_LEN, MAX_TGT_LEN


def encode_src_word(word, src_char2idx):
    """Encode a source word as a LongTensor of character ids."""
    SRC_UNK = src_char2idx["<unk>"]
    ids = [src_char2idx.get(c, SRC_UNK) for c in word]
    # Truncate overly long words to the configured maximum.
    ids = ids[:MAX_SRC_LEN]
    if not ids:
        # Guard against empty input so every graph has at least one node.
        ids = [SRC_UNK]
    return torch.tensor(ids, dtype=torch.long)
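
# A quick sketch of the source encoding with a toy vocabulary (the ids below
# are hypothetical, not the project's real vocab; assumes MAX_SRC_LEN >= 2):
#
#   src_char2idx = {"<pad>": 0, "<unk>": 1, "a": 2, "b": 3}
#   encode_src_word("ab", src_char2idx)  # -> tensor([2, 3])
#   encode_src_word("a?", src_char2idx)  # -> tensor([2, 1])  ('?' falls back to <unk>)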


def encode_tgt_word(word, tgt_char2idx):
    """Encode a target word as <sos> + character ids + <eos>."""
    TGT_SOS = tgt_char2idx["<sos>"]
    TGT_EOS = tgt_char2idx["<eos>"]
    TGT_UNK = tgt_char2idx["<unk>"]

    ids = [TGT_SOS]
    ids += [tgt_char2idx.get(c, TGT_UNK) for c in word]
    # Reserve one slot so the sequence still fits MAX_TGT_LEN after <eos>.
    ids = ids[: MAX_TGT_LEN - 1]
    ids.append(TGT_EOS)
    return torch.tensor(ids, dtype=torch.long)
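
# Sketch with a hypothetical target vocabulary (assumes MAX_TGT_LEN >= 4 so
# nothing is truncated here):
#
#   tgt_char2idx = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3, "a": 4, "b": 5}
#   encode_tgt_word("ab", tgt_char2idx)  # -> tensor([1, 4, 5, 2])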


def build_edge_index(num_nodes: int):
    """Build a bidirectional chain graph over the characters of a word:
    node i is connected to node i + 1 in both directions."""
    if num_nodes <= 1:
        # A single node (or an empty graph) has no edges.
        return torch.empty((2, 0), dtype=torch.long)
    src = []
    dst = []
    for i in range(num_nodes - 1):
        src.append(i)
        dst.append(i + 1)
        src.append(i + 1)
        dst.append(i)
    edge_index = torch.tensor([src, dst], dtype=torch.long)
    return edge_index
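
# For example, a 3-node word yields the chain 0 <-> 1 <-> 2:
#
#   build_edge_index(3)
#   # tensor([[0, 1, 1, 2],
#   #         [1, 0, 2, 1]])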


class WordTranslitDataset(Dataset):
    """Pairs each source word (as a character graph) with its target encoding."""

    def __init__(self, src_words, tgt_words, src_char2idx, tgt_char2idx):
        assert len(src_words) == len(tgt_words)
        self.src_words = src_words
        self.tgt_words = tgt_words
        self.src_char2idx = src_char2idx
        self.tgt_char2idx = tgt_char2idx

    def __len__(self):
        return len(self.src_words)

    def __getitem__(self, idx):
        sw = self.src_words[idx]
        tw = self.tgt_words[idx]

        src_ids = encode_src_word(sw, self.src_char2idx)
        tgt_ids = encode_tgt_word(tw, self.tgt_char2idx)

        # One node per encoded source character, chained bidirectionally.
        edge_index = build_edge_index(len(src_ids))

        return src_ids, edge_index, tgt_ids
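
# Each item is a (src_ids, edge_index, tgt_ids) triple; e.g. a 4-character
# source word gives src_ids of shape [4] and edge_index of shape [2, 6].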


def collate_fn(batch):
    """
    batch: list of (src_ids, edge_index, tgt_ids)
    Returns:
        x          : [total_nodes]
        edge_idx   : [2, total_edges]
        batch_vec  : [total_nodes] (graph index for each node)
        tgt_padded : [B, max_tgt_len]
        tgt_lengths: [B]
    """
    src_list, edge_list, tgt_list = zip(*batch)
    batch_size = len(src_list)

    # Merge the per-word graphs into one disconnected batch graph,
    # shifting each word's node indices by the running offset.
    all_nodes = []
    all_edges = []
    batch_vec = []
    node_offset = 0

    for i, (src_ids, edge_index) in enumerate(zip(src_list, edge_list)):
        num_nodes = src_ids.size(0)
        all_nodes.append(src_ids)
        if edge_index.numel() > 0:
            all_edges.append(edge_index + node_offset)
        batch_vec.append(torch.full((num_nodes,), i, dtype=torch.long))
        node_offset += num_nodes

    x = torch.cat(all_nodes, dim=0)
    batch_vec = torch.cat(batch_vec, dim=0)

    if all_edges:
        edge_idx = torch.cat(all_edges, dim=1)
    else:
        edge_idx = torch.empty((2, 0), dtype=torch.long)

    # Right-pad targets with zeros (index 0 is assumed to be <pad>).
    max_tgt_len = max(t.size(0) for t in tgt_list)
    tgt_padded = torch.zeros((batch_size, max_tgt_len), dtype=torch.long)
    tgt_lengths = torch.zeros(batch_size, dtype=torch.long)
    for i, t in enumerate(tgt_list):
        L = t.size(0)
        tgt_padded[i, :L] = t
        tgt_lengths[i] = L

    return x, edge_idx, batch_vec, tgt_padded, tgt_lengths
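
# Example wiring (a sketch; `train_src`, `train_tgt`, and the two vocab dicts
# are assumed to come from elsewhere in the project):
#
#   from torch.utils.data import DataLoader
#   ds = WordTranslitDataset(train_src, train_tgt, src_char2idx, tgt_char2idx)
#   loader = DataLoader(ds, batch_size=32, shuffle=True, collate_fn=collate_fn)
#   x, edge_idx, batch_vec, tgt_padded, tgt_lengths = next(iter(loader))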