import torch
import pandas as pd
from torch_geometric.data import InMemoryDataset, Data
from collections import Counter
from tqdm import tqdm

import setup


class Vocabulary:
    """Character-level vocabulary with index <-> string lookup tables."""

    def __init__(self, texts, special_tokens):
        # Special tokens occupy the first indices so they line up with the
        # constants in setup (setup.PAD_TOKEN, setup.SOS_TOKEN, setup.EOS_TOKEN).
        self.itos = {i: s for i, s in enumerate(special_tokens)}
        self.stoi = {s: i for i, s in self.itos.items()}
        char_counts = Counter(c for text in texts for c in text)
        for i, (char, _) in enumerate(char_counts.items(), start=len(special_tokens)):
            self.itos[i] = char
            self.stoi[char] = i

    def __len__(self):
        return len(self.itos)


class TeluguishDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        # First, call the parent constructor. It checks for the processed file
        # and runs self.process() to create it if it is missing.
        # The internal automatic loading WILL fail here, but that's okay
        # because we handle it next.
        super().__init__(root, transform, pre_transform, pre_filter)

        # --- THE FIX: manually load the data with the correct argument ---
        # This overrides the faulty internal loading of the parent class.
        # weights_only=False is required because the file contains pickled
        # Vocabulary objects, which torch.load's safer default refuses to load.
        loaded_data = torch.load(self.processed_paths[0], weights_only=False)

        # Unpack the tuple we saved in the process() method.
        collated_data_tuple, src_vocab, tgt_vocab = loaded_data
        self.data, self.slices = collated_data_tuple
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        # -----------------------------------------------------------------

    @property
    def raw_file_names(self):
        return [setup.DATA_FILE]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass  # The raw parquet file is expected to already be in raw_dir.

    def process(self):
        df = pd.read_parquet(self.raw_paths[0])
        df.dropna(inplace=True)

        # Normalize the romanized (Teluguish) text: lowercase, keep only
        # a-z and whitespace, then strip. The Telugu text is only stripped.
        df[setup.TELUGUISH_COL] = (df[setup.TELUGUISH_COL]
                                   .str.lower()
                                   .str.replace(r'[^a-z\s]', '', regex=True)
                                   .str.strip())
        df[setup.TELUGU_COL] = df[setup.TELUGU_COL].str.strip()
        df = df[(df[setup.TELUGUISH_COL].str.len() > 0) & (df[setup.TELUGU_COL].str.len() > 0)]

        print("Creating vocabularies...")
        # Order must match the index constants in setup:
        # <pad> -> setup.PAD_TOKEN, <sos> -> setup.SOS_TOKEN, <eos> -> setup.EOS_TOKEN.
        special_tokens = ['<pad>', '<sos>', '<eos>']
        src_vocab = Vocabulary(df[setup.TELUGUISH_COL], special_tokens)
        tgt_vocab = Vocabulary(df[setup.TELUGU_COL], special_tokens)

        print("Processing data into graph objects...")
        data_list = []
        for _, row in tqdm(df.iterrows(), total=len(df)):
            src_text = row[setup.TELUGUISH_COL]
            tgt_text = row[setup.TELUGU_COL]

            # Each source character becomes one node; unknown characters
            # fall back to the padding index.
            tokens = list(src_text)
            node_features = torch.tensor(
                [src_vocab.stoi.get(token, setup.PAD_TOKEN) for token in tokens],
                dtype=torch.long)

            # Connect adjacent characters in both directions (a bidirectional
            # chain graph).
            edge_list = []
            for i in range(len(tokens) - 1):
                edge_list.extend([[i, i + 1], [i + 1, i]])
            if edge_list:
                edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
            else:
                # A single-character input has no edges; keep the [2, 0]
                # shape that PyTorch Geometric expects.
                edge_index = torch.empty((2, 0), dtype=torch.long)

            # Target sequence is the Telugu text wrapped in <sos>/<eos> indices.
            target_seq_tensor = torch.tensor(
                [setup.SOS_TOKEN] + [tgt_vocab.stoi[c] for c in tgt_text] + [setup.EOS_TOKEN],
                dtype=torch.long)

            data = Data(x=node_features, edge_index=edge_index, target_seq=target_seq_tensor)
            data_list.append(data)

        # Collate the data_list and save everything in a tuple so that
        # __init__ can unpack it.
        data, slices = self.collate(data_list)
        torch.save(((data, slices), src_vocab, tgt_vocab), self.processed_paths[0])
        print("Done!")
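

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of constructing the dataset and inspecting a sample.
# The root path 'data/teluguish' is a hypothetical placeholder; on the first
# run it must contain raw/<setup.DATA_FILE> so process() can find the parquet.
if __name__ == '__main__':
    dataset = TeluguishDataset(root='data/teluguish')
    print(f"Samples: {len(dataset)}")
    print(f"Source vocab size: {len(dataset.src_vocab)}, "
          f"target vocab size: {len(dataset.tgt_vocab)}")
    sample = dataset[0]
    # Expected shapes: x=[num_chars], edge_index=[2, 2*(num_chars-1)],
    # target_seq=[len(tgt_text) + 2] (the +2 is the <sos>/<eos> wrapper).
    print(sample)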