# Transliteration/dataset.py

from collections import Counter

import pandas as pd
import torch
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm

import setup


class Vocabulary:
    def __init__(self, texts, special_tokens):
        # Special tokens take the first indices; the remaining characters are
        # numbered in order of first appearance in the corpus.
        self.itos = {i: s for i, s in enumerate(special_tokens)}
        self.stoi = {s: i for i, s in self.itos.items()}
        char_counts = Counter(c for text in texts for c in text)
        for i, (char, _) in enumerate(char_counts.items(), start=len(special_tokens)):
            self.itos[i] = char
            self.stoi[char] = i

    def __len__(self):
        return len(self.itos)
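

# Quick sanity check of Vocabulary's mapping (hypothetical values; assumes the
# three special tokens occupy indices 0-2 and characters are numbered in order
# of first appearance):
#     v = Vocabulary(["abba"], ['<pad>', '<sos>', '<eos>'])
#     v.stoi     # {'<pad>': 0, '<sos>': 1, '<eos>': 2, 'a': 3, 'b': 4}
#     len(v)     # 5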

class TeluguishDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        # The parent constructor checks for the processed file and runs
        # self.process() to create it if it is missing. Its own automatic
        # loading can fail because recent PyTorch versions default torch.load
        # to weights_only=True, which refuses to unpickle our Vocabulary
        # objects, so we reload the file manually below.
        super().__init__(root, transform, pre_transform, pre_filter)
        # Manually load with weights_only=False, overriding the parent
        # class's loading, and unpack the tuple saved in process().
        loaded_data = torch.load(self.processed_paths[0], weights_only=False)
        collated_data_tuple, src_vocab, tgt_vocab = loaded_data
        self.data, self.slices = collated_data_tuple
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    @property
    def raw_file_names(self):
        return [setup.DATA_FILE]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        df = pd.read_parquet(self.raw_paths[0])
        df.dropna(inplace=True)
        # Normalise the romanised (Teluguish) side: lowercase, keep only a-z
        # and whitespace, then trim.
        df[setup.TELUGUISH_COL] = (
            df[setup.TELUGUISH_COL]
            .str.lower()
            .str.replace(r'[^a-z\s]', '', regex=True)
            .str.strip()
        )
        df[setup.TELUGU_COL] = df[setup.TELUGU_COL].str.strip()
        # Drop rows that became empty after cleaning.
        df = df[(df[setup.TELUGUISH_COL].str.len() > 0) & (df[setup.TELUGU_COL].str.len() > 0)]
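        # For instance, a hypothetical raw pair ("Namaste! 123", " నమస్తే ")
        # would come out of the cleaning above as ("namaste", "నమస్తే").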
print("Creating vocabularies...")
special_tokens = ['<pad>', '<sos>', '<eos>']
src_vocab = Vocabulary(df[setup.TELUGUISH_COL], special_tokens)
tgt_vocab = Vocabulary(df[setup.TELUGU_COL], special_tokens)
print("Processing data into graph objects...")
data_list = []
        for _, row in tqdm(df.iterrows(), total=len(df)):
            src_text = row[setup.TELUGUISH_COL]
            tgt_text = row[setup.TELUGU_COL]
            tokens = list(src_text)
            # One node per character; unknown characters fall back to the pad index.
            node_features = torch.tensor(
                [src_vocab.stoi.get(token, setup.PAD_TOKEN) for token in tokens],
                dtype=torch.long)
            # Chain graph: connect consecutive characters in both directions.
            edge_list = []
            for i in range(len(tokens) - 1):
                edge_list.extend([[i, i + 1], [i + 1, i]])
            if edge_list:
                edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
            else:
                edge_index = torch.empty((2, 0), dtype=torch.long)  # 1-char input: keep the (2, 0) shape PyG expects
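            # e.g. a hypothetical 3-character input "abc" produces the
            # bidirectional chain 0 <-> 1 <-> 2:
            #     edge_index = [[0, 1, 1, 2],
            #                   [1, 0, 2, 1]]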
            # Target sequence with <sos>/<eos> markers for the decoder.
            target_seq_tensor = torch.tensor(
                [setup.SOS_TOKEN] + [tgt_vocab.stoi[c] for c in tgt_text] + [setup.EOS_TOKEN],
                dtype=torch.long)
            data = Data(x=node_features, edge_index=edge_index, target_seq=target_seq_tensor)
            data_list.append(data)

        # Collate the data_list and save everything in a tuple.
        data, slices = self.collate(data_list)
        torch.save(((data, slices), src_vocab, tgt_vocab), self.processed_paths[0])
        print("Done!")