# TransLiTE — src/dataset.py
# Author: Anudeep Tippabathuni
# (history: "Modified imports", commit dc29eb1)
# src/dataset.py
import torch
from torch.utils.data import Dataset
from .config import MAX_SRC_LEN, MAX_TGT_LEN
def encode_src_word(word, src_char2idx):
    """Map a source-script word to a 1-D LongTensor of character ids.

    Characters missing from ``src_char2idx`` fall back to the ``<unk>``
    id, the sequence is truncated to ``MAX_SRC_LEN``, and an empty word
    yields a single ``<unk>`` token so every sample has at least one node.
    """
    unk = src_char2idx["<unk>"]
    # Truncating the word first is equivalent to truncating the id list,
    # since each character maps to exactly one id.
    ids = [src_char2idx.get(ch, unk) for ch in word[:MAX_SRC_LEN]]
    return torch.tensor(ids or [unk], dtype=torch.long)
def encode_tgt_word(word, tgt_char2idx):
    """Encode a target word as ``<sos>`` + char ids + ``<eos>``.

    Unknown characters map to ``<unk>``. The prefix (``<sos>`` plus the
    word's ids) is capped at ``MAX_TGT_LEN - 1`` so that the appended
    ``<eos>`` never pushes the total length past ``MAX_TGT_LEN``.
    """
    unk = tgt_char2idx["<unk>"]
    body = [tgt_char2idx.get(ch, unk) for ch in word]
    seq = ([tgt_char2idx["<sos>"]] + body)[: MAX_TGT_LEN - 1]
    seq.append(tgt_char2idx["<eos>"])
    return torch.tensor(seq, dtype=torch.long)
def build_edge_index(num_nodes: int) -> torch.Tensor:
    """Build bidirectional chain-graph edges for a word of ``num_nodes`` chars.

    Adjacent characters i and i+1 are connected in both directions, in the
    interleaved order (i, i+1), (i+1, i) for i = 0..num_nodes-2.

    Args:
        num_nodes: number of character nodes in the word graph.

    Returns:
        LongTensor of shape [2, 2*(num_nodes-1)] (row 0 = source nodes,
        row 1 = destination nodes); empty [2, 0] tensor when the graph
        has fewer than two nodes.
    """
    if num_nodes <= 1:
        return torch.empty((2, 0), dtype=torch.long)
    # Vectorized construction (replaces a Python append loop) that
    # preserves the interleaved (i, i+1), (i+1, i) edge ordering.
    i = torch.arange(num_nodes - 1, dtype=torch.long)
    src = torch.stack([i, i + 1], dim=1).reshape(-1)
    dst = torch.stack([i + 1, i], dim=1).reshape(-1)
    return torch.stack([src, dst], dim=0)
class WordTranslitDataset(Dataset):
    """Paired (source word graph, target id sequence) transliteration data.

    Each item is a tuple ``(src_ids, edge_index, tgt_ids)`` where
    ``src_ids`` are source character ids, ``edge_index`` is the chain
    graph over those characters, and ``tgt_ids`` is the target sequence
    framed by ``<sos>``/``<eos>``.
    """

    def __init__(self, src_words, tgt_words, src_char2idx, tgt_char2idx):
        assert len(src_words) == len(tgt_words)
        self.src_words = src_words
        self.tgt_words = tgt_words
        self.src_char2idx = src_char2idx
        self.tgt_char2idx = tgt_char2idx

    def __len__(self):
        return len(self.src_words)

    def __getitem__(self, idx):
        """Return ``(src_ids, edge_index, tgt_ids)`` for the pair at ``idx``."""
        src_ids = encode_src_word(self.src_words[idx], self.src_char2idx)
        tgt_ids = encode_tgt_word(self.tgt_words[idx], self.tgt_char2idx)
        return src_ids, build_edge_index(src_ids.size(0)), tgt_ids
def collate_fn(batch, *, pad_value=0):
    """Merge a list of ``(src_ids, edge_index, tgt_ids)`` samples into a batch.

    The per-word chain graphs are packed into one disjoint graph by
    offsetting each sample's node indices (PyG-style batching), and the
    target sequences are right-padded to a common length.

    Args:
        batch: list of tuples ``(src_ids [n_i], edge_index [2, e_i],
            tgt_ids [t_i])`` as produced by ``WordTranslitDataset``.
        pad_value: id used to pad target sequences (default 0, matching
            the previous hard-coded zero padding).

    Returns:
        x:          [total_nodes]    concatenated source char ids
        edge_idx:   [2, total_edges] offset-adjusted edges
        batch_vec:  [total_nodes]    graph index of each node
        tgt_padded: [B, max_tgt_len] right-padded target ids
        tgt_lengths:[B]              true (unpadded) target lengths
    """
    src_list, edge_list, tgt_list = zip(*batch)

    # Merge per-sample graphs: shift each sample's edges by the number of
    # nodes that precede it, and record which graph each node belongs to.
    all_edges = []
    batch_vec = []
    node_offset = 0
    for i, (src_ids, edge_index) in enumerate(zip(src_list, edge_list)):
        num_nodes = src_ids.size(0)
        if edge_index.numel() > 0:  # single-node words contribute no edges
            all_edges.append(edge_index + node_offset)
        batch_vec.append(torch.full((num_nodes,), i, dtype=torch.long))
        node_offset += num_nodes

    x = torch.cat(src_list, dim=0)          # [total_nodes]
    batch_vec = torch.cat(batch_vec, dim=0)  # [total_nodes]
    if all_edges:
        edge_idx = torch.cat(all_edges, dim=1)  # [2, total_edges]
    else:
        # Every word in the batch had a single character.
        edge_idx = torch.empty((2, 0), dtype=torch.long)

    # pad_sequence replaces the manual zero-tensor + copy loop.
    tgt_padded = torch.nn.utils.rnn.pad_sequence(
        tgt_list, batch_first=True, padding_value=pad_value
    )
    tgt_lengths = torch.tensor([t.size(0) for t in tgt_list], dtype=torch.long)
    return x, edge_idx, batch_vec, tgt_padded, tgt_lengths