import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig

from main_folder.code_base.utils import ArcMarginProduct, CurricularFace


class TextEncoder(nn.Module):
    """Transformer text encoder with a metric-learning head (ArcFace / CurricularFace)."""

    def __init__(
        self,
        num_classes,
        embed_size=1024,
        max_seq_length=35,
        backbone=None,
        dropout=0.5,
        scale=30.0,
        margin=0.5,
        final_layer="arcface",
        device="cuda",
        eval_model=False,
        alpha=0.0,
    ):
        super().__init__()
        self.backbone_name = backbone
        # In eval mode the backbone is built from its config only (random weights),
        # on the assumption that fine-tuned weights are loaded from a checkpoint later.
        if eval_model:
            self.config = AutoConfig.from_pretrained(backbone)
            self.backbone = AutoModel.from_config(self.config)
        else:
            self.backbone = AutoModel.from_pretrained(backbone)
        self.out_features = num_classes
        self.embed_size = embed_size
        self.scale = scale
        self.margin = margin
        self.device = device
        # NOTE: `dropout` is accepted for API compatibility but not applied anywhere below.
        if final_layer == "arcface":
            self.final = ArcMarginProduct(
                in_features=self.embed_size,
                out_features=self.out_features,
                s=self.scale,
                m=self.margin,
                device=self.device,
                alpha=alpha,
            )
        elif final_layer == "currface":
            self.final = CurricularFace(
                in_features=self.embed_size,
                out_features=self.out_features,
                s=self.scale,
                m=self.margin,
            )
        else:
            raise ValueError(f"Unknown final_layer: {final_layer!r}")
        self.fc = nn.Linear(self.backbone.config.hidden_size, self.embed_size)
        # Average-pools over the token dimension; inputs must be padded/truncated
        # to exactly `max_seq_length` (padding tokens are included in the average).
        self.pool = nn.AvgPool1d(kernel_size=max_seq_length)
        self.bn = nn.BatchNorm1d(self.embed_size)

    def forward(self, input_ids, attention_mask, labels=None):
        # (batch, seq_len, hidden_size) token representations from the backbone.
        features = self.backbone(
            input_ids, attention_mask=attention_mask
        ).last_hidden_state
        features = self.fc(features)                     # project to embed_size
        features = features.transpose(1, 2)              # (batch, embed_size, seq_len)
        features = self.pool(features)                   # mean over tokens
        features = features.view(features.size(0), -1)   # (batch, embed_size)
        features = self.bn(features)
        features = F.normalize(features)                 # L2-normalized embeddings
        # With labels: return margin-based logits for training; otherwise return embeddings.
        if labels is not None:
            return self.final(features, labels)
        return features
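

if __name__ == "__main__":
    # Minimal usage sketch, for illustration only: it assumes a HuggingFace
    # checkpoint name ("bert-base-uncased" is an arbitrary example) and that the
    # ArcMarginProduct/CurricularFace utilities imported above are available.
    # It shows the inference path (labels=None), which returns L2-normalized embeddings.
    from transformers import AutoTokenizer

    model_name = "bert-base-uncased"  # assumed checkpoint, not prescribed by this module
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = TextEncoder(
        num_classes=10,
        embed_size=1024,
        max_seq_length=35,
        backbone=model_name,
        final_layer="arcface",
        device="cpu",
    )
    model.eval()

    batch = tokenizer(
        ["a sample product title", "another sample title"],
        padding="max_length",   # pad to exactly max_seq_length so AvgPool1d sees full-length sequences
        truncation=True,
        max_length=35,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = model(batch["input_ids"], batch["attention_mask"])
    print(embeddings.shape)  # expected: torch.Size([2, 1024])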