|
|
import timm |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from timm.layers import ScaledStdConv2d, ScaledStdConv2dSame, BatchNormAct2d |
|
|
from main_folder.code_base.utils import ArcMarginProduct, CurricularFace |
|
|
from main_folder.code_base.pipeline.gempool import GeM |
|
|
from main_folder.code_base.pipeline.depthconv import DepthwiseSeparableConv |
|
|
|
|
|
|
|
|
class ImgEncoder(nn.Module):
    """Image embedding model: timm backbone -> 1x1 conv -> GeM -> BatchNorm -> L2 norm.

    When ``labels`` are passed to :meth:`forward`, the normalized embedding is
    additionally fed through a margin-based classification head
    (``ArcMarginProduct`` or ``CurricularFace``) and that head's output is
    returned instead of the embedding.

    Args:
        num_classes: Number of output classes for the margin head.
        embed_size: Channel count of the 1x1 conv, i.e. the embedding size.
        backbone: timm model name passed to ``timm.create_model``. It must be
            supplied in practice; the ``None`` default is presumably not a
            name timm accepts.
        pretrained: Forwarded to ``timm.create_model``.
        scale: Passed to the margin head as ``s``.
        margin: Passed to the margin head as ``m``.
        alpha: Passed to the margin head as ``alpha``.
        final_layer: ``"arcface"`` builds ``ArcMarginProduct``; ``"currface"``
            builds ``CurricularFace``. Any other value builds no head, in which
            case ``forward`` only supports ``labels=None``.
        device: Forwarded to ``ArcMarginProduct`` only. This module itself is
            not moved to ``device``.
        permute: If True, the backbone feature map is permuted with
            ``(0, 3, 1, 2)`` before the 1x1 conv -- presumably for backbones
            whose ``forward_features`` returns channels-last maps.
        p: Exponent passed to ``GeM``.
    """

    def __init__(
        self,
        num_classes,
        embed_size=1792,
        backbone=None,
        pretrained=True,
        scale=30.0,
        margin=0.5,
        alpha=0.0,
        final_layer="arcface",
        device="cuda",
        permute=False,
        p=3,
    ):
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained=pretrained)
        self.embed_size = embed_size
        self.num_classes = num_classes
        self.margin = margin
        self.scale = scale
        self.device = device
        self.p = p

        # Project the backbone's feature channels to the embedding size.
        self.final_conv = nn.Conv2d(
            self.backbone.num_features,
            self.embed_size,
            kernel_size=1,
        )

        # Margin head, used only when forward() receives labels. Always define
        # self.final so an unsupported final_layer is detectable in forward()
        # instead of surfacing later as a missing-attribute error.
        if final_layer == "arcface":
            self.final = ArcMarginProduct(
                in_features=self.embed_size,
                out_features=self.num_classes,
                s=self.scale,
                m=self.margin,
                alpha=alpha,
                device=self.device,
            )
        elif final_layer == "currface":
            self.final = CurricularFace(
                in_features=self.embed_size,
                out_features=self.num_classes,
                s=self.scale,
                m=self.margin,
                alpha=alpha,
            )
        else:
            self.final = None

        self.gem = GeM(p=self.p)
        self.bn = nn.BatchNorm1d(self.embed_size)
        self.permute = permute

    def forward(self, x, labels=None):
        """Embed ``x`` and, if ``labels`` is given, apply the margin head.

        Args:
            x: Input batch for the backbone (presumably images shaped
                (B, C, H, W) -- TODO confirm against the data pipeline).
            labels: Optional targets. When given, the result of
                ``self.final(embeddings, labels)`` is returned.

        Returns:
            L2-normalized embeddings of shape (B, embed_size) when ``labels``
            is None; otherwise the margin head's output.

        Raises:
            ValueError: if ``labels`` is given but no margin head was built
                (``final_layer`` was neither ``"arcface"`` nor ``"currface"``).
        """
        # Fail fast, before the expensive backbone pass, with a clear message.
        if labels is not None and self.final is None:
            raise ValueError(
                "labels were given but no margin head was built "
                "(final_layer must be 'arcface' or 'currface')"
            )

        features = self.backbone.forward_features(x)
        if self.permute:
            features = features.permute(0, 3, 1, 2)
        features = self.final_conv(features)
        features = self.gem(features)
        # flatten(1) matches view(B, -1) but also works on non-contiguous input.
        features = features.flatten(1)
        features = self.bn(features)
        # Unit-normalize each embedding (F.normalize defaults to dim=1).
        features = F.normalize(features)

        if labels is not None:
            return self.final(features, labels)
        return features
|
|
|