"""Transliteration / preprocess.py

Author: Anudeep Tippabathuni
"""
import torch
import pandas as pd
from torch_geometric.data import Data
from tqdm import tqdm
import setup
from dataset import Vocabulary # We'll import the Vocabulary class from our slimmed-down dataset.py
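# Assumed interface of the local modules used below (inferred from how this script calls them):
#   setup.py   -> DATA_FILE, TELUGUISH_COL, TELUGU_COL, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN
#   dataset.py -> Vocabulary(texts, special_tokens) exposing a .stoi dict mapping token -> integer id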
print("Step 1: Loading and Cleaning Raw Data...")
df = pd.read_parquet(setup.DATA_FILE)
df.dropna(inplace=True)
# Basic cleaning
df[setup.TELUGUISH_COL] = df[setup.TELUGUISH_COL].str.lower().str.replace(r'[^a-z\s]', '', regex=True).str.strip()
df[setup.TELUGU_COL] = df[setup.TELUGU_COL].str.strip()
# Filter out any rows that became empty after cleaning
initial_rows = len(df)
df = df[df[setup.TELUGUISH_COL].str.len() > 0]
df = df[df[setup.TELUGU_COL].str.len() > 0]
print(f"Data Cleaning: Kept {len(df)} rows, filtered out {initial_rows - len(df)} empty rows.")
print("\nStep 2: Creating Vocabularies...")
special_tokens = ['<pad>', '<sos>', '<eos>']
src_vocab = Vocabulary(df[setup.TELUGUISH_COL], special_tokens)
tgt_vocab = Vocabulary(df[setup.TELUGU_COL], special_tokens)
print("Vocabularies created.")
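# Hedged assumption: setup.PAD_TOKEN / SOS_TOKEN / EOS_TOKEN are the integer ids that the
# Vocabulary assigns to '<pad>', '<sos>', '<eos>' above; the tensors built below rely on that.
print(f"Source vocab size: {len(src_vocab.stoi)} | Target vocab size: {len(tgt_vocab.stoi)}")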
print("\nStep 3: Converting Sentences to Graph Objects...")
data_list = []
for _, row in tqdm(df.iterrows(), total=len(df), desc="Preprocessing Data into Graphs"):
    src_text = row[setup.TELUGUISH_COL]
    tgt_text = row[setup.TELUGU_COL]

    # Each character of the romanized source becomes one graph node; unknown
    # characters (none expected after cleaning) fall back to the pad id.
    tokens = list(src_text)
    node_features = torch.tensor(
        [src_vocab.stoi.get(token, setup.PAD_TOKEN) for token in tokens],
        dtype=torch.long
    )

    # Connect consecutive characters with bidirectional edges (a chain graph).
    edge_list = []
    if len(tokens) > 1:
        for i in range(len(tokens) - 1):
            edge_list.append([i, i + 1])
            edge_list.append([i + 1, i])
    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        # Single-character inputs have no edges; keep the [2, 0] shape PyG expects.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # Target: the Telugu characters wrapped with start/end-of-sequence ids.
    target_seq_tensor = torch.tensor(
        [setup.SOS_TOKEN] + [tgt_vocab.stoi[c] for c in tgt_text] + [setup.EOS_TOKEN],
        dtype=torch.long
    )

    data = Data(x=node_features, edge_index=edge_index, target_seq=target_seq_tensor)
    data_list.append(data)
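
# Optional sanity check (assumes at least one row survived cleaning): inspect the first
# converted example to confirm node/edge/target shapes look right before saving.
if data_list:
    example = data_list[0]
    print(f"Example graph: {example.num_nodes} nodes, "
          f"{example.edge_index.size(1)} edges, "
          f"target length {example.target_seq.size(0)}")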
print("\nStep 4: Saving Preprocessed Data to File...")
torch.save({
    'data_list': data_list,
    'src_vocab': src_vocab,
    'tgt_vocab': tgt_vocab
}, 'preprocessed_data.pt')
print("\nPreprocessing complete. Data saved to 'preprocessed_data.pt'")
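
# Usage sketch (hedged): a downstream training script would reload the artifact roughly like this.
# Note: the file contains pickled Vocabulary objects, so recent PyTorch versions need
# weights_only=False when loading.
#
#   checkpoint = torch.load('preprocessed_data.pt', weights_only=False)
#   data_list = checkpoint['data_list']
#   src_vocab, tgt_vocab = checkpoint['src_vocab'], checkpoint['tgt_vocab']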