File size: 1,036 Bytes
fcd2005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from main_folder.code_base.utils import clean_text

class SHOPEETextDataset(Dataset):
    def __init__(self, df, tokenizer=None, gen_feat_only=False, clean=False):
        self.df = df
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.only_feat = gen_feat_only
        self.clean = clean

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.loc[index]
        text = row.title
        if self.clean:
            text = clean_text(text)
        text = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        )
        input_ids = text["input_ids"][0]
        attention_mask = text["attention_mask"][0]
        if self.only_feat:
            return input_ids, attention_mask
        return input_ids, attention_mask, torch.tensor(row.label_group).float()