from collections import Counter

import pandas as pd
import torch
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm

import setup


class Vocabulary:
    """Character-level vocabulary whose special tokens take the lowest ids."""

    def __init__(self, texts, special_tokens):
        # Special tokens occupy ids 0 .. len(special_tokens) - 1.
        self.itos = {i: s for i, s in enumerate(special_tokens)}
        self.stoi = {s: i for i, s in self.itos.items()}
        # Count every character in the corpus, then assign ids in first-seen
        # order after the special tokens.
        char_counts = Counter(c for text in texts for c in text)
        for i, (char, _) in enumerate(char_counts.items(), start=len(special_tokens)):
            self.itos[i] = char
            self.stoi[char] = i

    def __len__(self):
        return len(self.itos)
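
# A minimal usage sketch (hypothetical strings, not from the dataset):
#   vocab = Vocabulary(["nenu", "nuvvu"], ['<pad>', '<sos>', '<eos>'])
#   vocab.stoi['<pad>']  # -> 0: special tokens take the lowest ids
#   vocab.stoi['n']      # -> 3: first character seen after the specials
#   len(vocab)           # -> 7: 3 special tokens + 4 distinct characters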


class TeluguishDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        # The parent constructor checks for the processed file and, if it is
        # missing, runs self.process() to create it. Its own attempt to load
        # the processed file may fail, but that is fine: we reload the file
        # manually right below with the arguments we need.
        super().__init__(root, transform, pre_transform, pre_filter)

        # The fix: load the processed file ourselves with weights_only=False,
        # overriding the parent class's loading. The file contains pickled
        # Python objects (the Vocabulary instances), which the default loader
        # refuses to deserialize.
        loaded_data = torch.load(self.processed_paths[0], weights_only=False)

        # Unpack the tuple saved at the end of process().
        collated_data_tuple, src_vocab, tgt_vocab = loaded_data
        self.data, self.slices = collated_data_tuple
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
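        # A hedged note on why this is needed: PyTorch 2.6 changed the
        # torch.load default to weights_only=True, which rejects pickled
        # custom classes such as Vocabulary, and the load PyG performs
        # internally does not pass weights_only=False. Whether the internal
        # load actually fails depends on your torch/torch_geometric versions.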

    @property
    def raw_file_names(self):
        return [setup.DATA_FILE]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # No download step: the raw parquet file is expected to already be
        # in self.raw_dir.
        pass

    def process(self):
        df = pd.read_parquet(self.raw_paths[0])
        df.dropna(inplace=True)

        # Normalize the romanized (Teluguish) side: lowercase, keep only a-z
        # and whitespace, strip surrounding whitespace.
        df[setup.TELUGUISH_COL] = df[setup.TELUGUISH_COL].str.lower().str.replace(r'[^a-z\s]', '', regex=True).str.strip()
        df[setup.TELUGU_COL] = df[setup.TELUGU_COL].str.strip()
        # Drop rows that became empty after cleaning.
        df = df[(df[setup.TELUGUISH_COL].str.len() > 0) & (df[setup.TELUGU_COL].str.len() > 0)]

        print("Creating vocabularies...")
        special_tokens = ['<pad>', '<sos>', '<eos>']
        src_vocab = Vocabulary(df[setup.TELUGUISH_COL], special_tokens)
        tgt_vocab = Vocabulary(df[setup.TELUGU_COL], special_tokens)

        print("Processing data into graph objects...")
        data_list = []
        for _, row in tqdm(df.iterrows(), total=len(df)):
            src_text = row[setup.TELUGUISH_COL]
            tgt_text = row[setup.TELUGU_COL]

            # One node per source character; unknown characters fall back to
            # the padding id.
            tokens = list(src_text)
            node_features = torch.tensor([src_vocab.stoi.get(token, setup.PAD_TOKEN) for token in tokens], dtype=torch.long)

            # Chain graph over the characters, with edges in both directions:
            # for "abc" the edge list is [[0, 1], [1, 0], [1, 2], [2, 1]].
            edge_list = []
            for i in range(len(tokens) - 1):
                edge_list.extend([[i, i + 1], [i + 1, i]])
            if edge_list:
                edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
            else:
                # Single-character input: no edges, but keep the (2, 0) shape
                # that PyG expects for edge_index.
                edge_index = torch.empty((2, 0), dtype=torch.long)

            # Target sequence is <sos> + target characters + <eos>.
            target_seq_tensor = torch.tensor(
                [setup.SOS_TOKEN] + [tgt_vocab.stoi[c] for c in tgt_text] + [setup.EOS_TOKEN], dtype=torch.long)

            data = Data(x=node_features, edge_index=edge_index, target_seq=target_seq_tensor)
            data_list.append(data)

        # Collate the list into one big Data object plus slices and save it
        # together with both vocabularies, matching the unpacking in __init__.
        data, slices = self.collate(data_list)
        torch.save(((data, slices), src_vocab, tgt_vocab), self.processed_paths[0])
        print("Done!")