# src/dataset.py
import torch
from torch.utils.data import Dataset

from .config import MAX_SRC_LEN, MAX_TGT_LEN


def encode_src_word(word, src_char2idx):
    """Map a source word to a LongTensor of character ids, truncated to MAX_SRC_LEN."""
    SRC_UNK = src_char2idx["<unk>"]
    ids = [src_char2idx.get(c, SRC_UNK) for c in word]
    ids = ids[:MAX_SRC_LEN]
    if not ids:  # guard against empty words so every sample has at least one node
        ids = [SRC_UNK]
    return torch.tensor(ids, dtype=torch.long)


def encode_tgt_word(word, tgt_char2idx):
    """Map a target word to <sos> + character ids + <eos>, capped at MAX_TGT_LEN."""
    TGT_SOS = tgt_char2idx["<sos>"]
    TGT_EOS = tgt_char2idx["<eos>"]
    TGT_UNK = tgt_char2idx["<unk>"]
    ids = [TGT_SOS]
    ids += [tgt_char2idx.get(c, TGT_UNK) for c in word]
    ids = ids[: MAX_TGT_LEN - 1]  # leave room for <eos>
    ids.append(TGT_EOS)
    return torch.tensor(ids, dtype=torch.long)


def build_edge_index(num_nodes: int):
    """Chain graph over a word's characters: bidirectional edges between neighbors."""
    if num_nodes <= 1:
        return torch.empty((2, 0), dtype=torch.long)
    src, dst = [], []
    for i in range(num_nodes - 1):
        src.append(i)
        dst.append(i + 1)
        src.append(i + 1)
        dst.append(i)
    return torch.tensor([src, dst], dtype=torch.long)


class WordTranslitDataset(Dataset):
    """Pairs of (source word graph, target character sequence) for transliteration."""

    def __init__(self, src_words, tgt_words, src_char2idx, tgt_char2idx):
        assert len(src_words) == len(tgt_words)
        self.src_words = src_words
        self.tgt_words = tgt_words
        self.src_char2idx = src_char2idx
        self.tgt_char2idx = tgt_char2idx

    def __len__(self):
        return len(self.src_words)

    def __getitem__(self, idx):
        src_ids = encode_src_word(self.src_words[idx], self.src_char2idx)
        tgt_ids = encode_tgt_word(self.tgt_words[idx], self.tgt_char2idx)
        edge_index = build_edge_index(src_ids.size(0))
        return src_ids, edge_index, tgt_ids


def collate_fn(batch):
    """
    batch: list of (src_ids, edge_index, tgt_ids)

    Returns:
        x          : [total_nodes]      node features (character ids)
        edge_idx   : [2, total_edges]   merged edge index
        batch_vec  : [total_nodes]      graph index for each node
        tgt_padded : [B, max_tgt_len]   zero-padded target sequences
        tgt_lengths: [B]                true target lengths
    """
    src_list, edge_list, tgt_list = zip(*batch)
    batch_size = len(src_list)

    # Merge the per-word chain graphs into one disconnected graph,
    # shifting each word's node indices by the running node offset.
    all_nodes, all_edges, batch_vec = [], [], []
    node_offset = 0
    for i, (src_ids, edge_index) in enumerate(zip(src_list, edge_list)):
        num_nodes = src_ids.size(0)
        all_nodes.append(src_ids)
        if edge_index.numel() > 0:
            all_edges.append(edge_index + node_offset)
        batch_vec.append(torch.full((num_nodes,), i, dtype=torch.long))
        node_offset += num_nodes

    x = torch.cat(all_nodes, dim=0)          # [total_nodes]
    batch_vec = torch.cat(batch_vec, dim=0)  # [total_nodes]
    if all_edges:
        edge_idx = torch.cat(all_edges, dim=1)  # [2, total_edges]
    else:
        edge_idx = torch.empty((2, 0), dtype=torch.long)

    # Pad target sequences to the longest one in the batch (pad id = 0).
    max_tgt_len = max(t.size(0) for t in tgt_list)
    tgt_padded = torch.zeros((batch_size, max_tgt_len), dtype=torch.long)
    tgt_lengths = torch.zeros(batch_size, dtype=torch.long)
    for i, t in enumerate(tgt_list):
        L = t.size(0)
        tgt_padded[i, :L] = t
        tgt_lengths[i] = L

    return x, edge_idx, batch_vec, tgt_padded, tgt_lengths
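
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It builds toy
# vocabularies and pushes one batch through DataLoader + collate_fn. The
# word lists and the make_vocab helper below are illustrative assumptions,
# not values from the real .config; it assumes MAX_SRC_LEN and MAX_TGT_LEN
# are at least 5 and 7. Run as `python -m src.dataset` so the relative
# import of .config resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    def make_vocab(words, specials):
        # id 0 is reserved for padding, matching collate_fn's zero-padding
        vocab = {"<pad>": 0}
        for tok in specials:
            vocab[tok] = len(vocab)
        for w in words:
            for c in w:
                vocab.setdefault(c, len(vocab))
        return vocab

    src_words = ["hello", "graph"]
    tgt_words = ["halo", "graf"]
    src_vocab = make_vocab(src_words, ["<unk>"])
    tgt_vocab = make_vocab(tgt_words, ["<sos>", "<eos>", "<unk>"])

    ds = WordTranslitDataset(src_words, tgt_words, src_vocab, tgt_vocab)
    loader = DataLoader(ds, batch_size=2, collate_fn=collate_fn)

    x, edge_idx, batch_vec, tgt_padded, tgt_lengths = next(iter(loader))
    # Two 5-character words -> 10 nodes; each chain has 4 * 2 = 8 edges.
    print(x.shape, edge_idx.shape, batch_vec.shape)  # [10], [2, 16], [10]
    # "halo"/"graf" -> <sos> + 4 chars + <eos> = length 6 each.
    print(tgt_padded.shape, tgt_lengths)             # [2, 6], tensor([6, 6])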