import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import ScaledStdConv2d, ScaledStdConv2dSame, BatchNormAct2d
from main_folder.code_base.utils import ArcMarginProduct, CurricularFace
from main_folder.code_base.pipeline.gempool import GeM
from main_folder.code_base.pipeline.depthconv import DepthwiseSeparableConv


class ImgEncoder(nn.Module):
    def __init__(
        self,
        num_classes,
        embed_size=1792,
        backbone=None,
        pretrained=True,
        scale=30.0,
        margin=0.5,
        alpha=0.0,
        final_layer="arcface",
        device="cuda",
        permute=False,
        p=3,
    ):
        super().__init__()
        # Backbone from timm, used only as a feature extractor.
        self.backbone = timm.create_model(backbone, pretrained=pretrained)
        self.embed_size = embed_size  # size of the output embedding
        self.num_classes = num_classes  # number of target classes
        self.margin = margin
        self.scale = scale
        self.device = device
        self.p = p
        # 1x1 conv projects the backbone's feature channels to the embedding size.
        self.final_conv = nn.Conv2d(
            self.backbone.num_features,
            self.embed_size,
            kernel_size=1,
        )
        # Margin-based classification head, used only when labels are given.
        if final_layer == "arcface":
            self.final = ArcMarginProduct(
                in_features=self.embed_size,
                out_features=self.num_classes,
                s=self.scale,
                m=self.margin,
                alpha=alpha,
                device=self.device,
            )
        elif final_layer == "currface":
            self.final = CurricularFace(
                in_features=self.embed_size,
                out_features=self.num_classes,
                s=self.scale,
                m=self.margin,
                alpha=alpha,
            )
        else:
            raise ValueError(f"Unsupported final_layer: {final_layer}")
        self.gem = GeM(p=self.p)  # generalized mean pooling over spatial dims
        self.bn = nn.BatchNorm1d(self.embed_size)
        self.permute = permute  # set True for backbones that return channels-last feature maps

    def forward(self, x, labels=None):
        features = self.backbone.forward_features(x)
        if self.permute:
            # NHWC -> NCHW for backbones that return channels-last feature maps.
            features = torch.permute(features, (0, 3, 1, 2))
        features = self.final_conv(features)
        features = self.gem(features)
        features = features.view(features.size(0), -1)
        features = self.bn(features)
        features = F.normalize(features)
        if labels is not None:
            # Training: return margin-based logits for the loss.
            return self.final(features, labels)
        # Inference: return the L2-normalized embedding.
        return features
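

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The backbone name,
# class count, input resolution, and device below are illustrative
# assumptions; any timm backbone and matching settings should work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = ImgEncoder(
        num_classes=1000,               # assumed number of training classes
        embed_size=1792,
        backbone="tf_efficientnet_b0",  # illustrative timm backbone choice
        pretrained=False,               # skip weight download for a quick smoke test
        final_layer="arcface",
        device="cpu",
    )
    model.eval()
    dummy = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
    with torch.no_grad():
        emb = model(dummy)               # no labels -> L2-normalized embeddings
    print(emb.shape)                     # expected: torch.Size([2, 1792])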