| import torch | |
| from torch.utils.data import Dataset | |
| from transformers import AutoTokenizer | |
| from main_folder.code_base.utils import clean_text | |
| class SHOPEETextDataset(Dataset): | |
| def __init__(self, df, tokenizer=None, gen_feat_only=False, clean=False): | |
| self.df = df | |
| self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) | |
| self.only_feat = gen_feat_only | |
| self.clean = clean | |
| def __len__(self): | |
| return len(self.df) | |
| def __getitem__(self, index): | |
| row = self.df.loc[index] | |
| text = row.title | |
| if self.clean: | |
| text = clean_text(text) | |
| text = self.tokenizer( | |
| text, | |
| padding="max_length", | |
| truncation=True, | |
| max_length=35, | |
| return_tensors="pt", | |
| ) | |
| input_ids = text["input_ids"][0] | |
| attention_mask = text["attention_mask"][0] | |
| if self.only_feat: | |
| return input_ids, attention_mask | |
| return input_ids, attention_mask, torch.tensor(row.label_group).float() | |