import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig

from main_folder.code_base.utils import ArcMarginProduct, CurricularFace
class TextEncoder(nn.Module):
    """Transformer text encoder with a metric-learning head (ArcFace or CurricularFace)."""

    def __init__(
        self,
        num_classes,
        embed_size=1024,
        max_seq_length=35,
        backbone=None,
        dropout=0.5,  # note: accepted but not applied anywhere in this module
        scale=30.0,
        margin=0.5,
        final_layer="arcface",
        device="cuda",
        eval_model=False,
        alpha=0.0,
    ):
        super().__init__()
        self.backbone_name = backbone
        if eval_model:
            # Build the architecture from config only (no pretrained weights),
            # e.g. when weights are loaded later from a checkpoint.
            self.config = AutoConfig.from_pretrained(backbone)
            self.backbone = AutoModel.from_config(self.config)
        else:
            self.backbone = AutoModel.from_pretrained(backbone)
        self.out_features = num_classes
        self.embed_size = embed_size
        self.scale = scale
        self.margin = margin
        self.device = device
        if final_layer == "arcface":
            self.final = ArcMarginProduct(
                in_features=self.embed_size,
                out_features=self.out_features,
                s=self.scale,
                m=self.margin,
                device=self.device,
                alpha=alpha,
            )
        elif final_layer == "currface":
            self.final = CurricularFace(
                in_features=self.embed_size,
                out_features=self.out_features,
                s=self.scale,
                m=self.margin,
            )
        else:
            raise ValueError(f"Unknown final_layer: {final_layer!r}")
        # Project backbone hidden states to the embedding size, then average-pool
        # over the token dimension. The pool assumes every sequence is
        # padded/truncated to exactly max_seq_length tokens.
        self.fc = nn.Linear(self.backbone.config.hidden_size, self.embed_size)
        self.pool = nn.AvgPool1d(kernel_size=max_seq_length)
        self.bn = nn.BatchNorm1d(self.embed_size)
    def forward(self, input_ids, attention_mask, labels=None):
        # (batch, seq_len, hidden) token representations from the backbone
        features = self.backbone(
            input_ids, attention_mask=attention_mask
        ).last_hidden_state
        features = self.fc(features)                    # (batch, seq_len, embed_size)
        features = features.transpose(1, 2)             # (batch, embed_size, seq_len)
        features = self.pool(features)                  # (batch, embed_size, 1)
        features = features.view(features.size(0), -1)  # (batch, embed_size)
        features = self.bn(features)
        features = F.normalize(features)                # L2-normalize for the margin head
        if labels is not None:
            # Training path: margin-adjusted logits for a cross-entropy loss
            return self.final(features, labels)
        # Inference path: the normalized embedding
        return features
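

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is an assumption for illustration:
# the backbone name, num_classes, labels, and batch are hypothetical, and the
# tokenizer call uses the standard Hugging Face AutoTokenizer API. Sequences
# are padded to max_seq_length, which the AvgPool1d above requires.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    backbone = "bert-base-uncased"  # hypothetical backbone choice
    model = TextEncoder(
        num_classes=100,
        embed_size=1024,
        max_seq_length=35,
        backbone=backbone,
        final_layer="arcface",
        device="cpu",
    )

    tokenizer = AutoTokenizer.from_pretrained(backbone)
    batch = tokenizer(
        ["a sample sentence", "another sample"],
        padding="max_length",
        truncation=True,
        max_length=35,
        return_tensors="pt",
    )
    labels = torch.tensor([3, 7])  # hypothetical class ids

    # Training path: margin-adjusted logits; inference path: embeddings.
    logits = model(batch["input_ids"], batch["attention_mask"], labels=labels)
    embeddings = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape, embeddings.shape)  # expected: (2, 100) and (2, 1024)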