# Page-scrape residue, not part of the module: "Anirban0011's picture" / "upd" / "fcd2005"
import torch
import math
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
class ArcMarginProduct(nn.Module):
    r"""Large-margin arc distance head: ``cos(theta + m)`` on the target class.

    ``cos(theta)`` is the linear product of the input with the L2-normalised
    class weights.  The target-class logit is replaced by
    ``cos(theta + m)`` and every logit is finally multiplied by ``s``.

    NOTE(review): only the weights are normalised in ``forward``; the input
    is presumably expected to be L2-normalised by the caller (the original
    doc describes ``s`` as the "norm of input feature") -- confirm.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample (one logit per class).
        s: scale multiplied into the returned logits.
        m: base margin added to the angle.
        easy_margin: if True, apply the margin only where ``cos(theta) > 0``
            and keep the plain cosine elsewhere.
        ls_eps: label-smoothing factor; values ``<= 0`` disable smoothing.
        alpha: per-epoch margin increment used by ``update_margin``.
        device: stored on the module for backward compatibility.  ``forward``
            no longer uses it; intermediate tensors are created on the device
            of the computed logits instead.
    """

    def __init__(
        self,
        in_features,
        out_features,
        s=30.0,
        m=0.50,
        easy_margin=False,
        ls_eps=0.0,
        alpha=0.0,
        device="cuda",
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.alpha = alpha
        self.ls_eps = ls_eps  # label smoothing
        self.device = device
        self.weight = Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self._set_margin(m)

    def _set_margin(self, m):
        """Cache the trigonometric constants derived from margin ``m``."""
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Below this cosine, theta + m would exceed pi.
        self.th = math.cos(math.pi - m)
        # Linear penalty subtracted from the cosine in that region.
        self.mm = math.sin(math.pi - m) * m

    def __repr__(self):
        return (
            f"{self.in_features}, {self.out_features}, s={self.s}, m = {self.m}, "
            f"easy_margin = {self.easy_margin}, ls_eps = {self.ls_eps} "
        )

    def update_margin(self, epoch):
        """Set the effective margin to ``self.m + self.alpha * (epoch + 1)``.

        ``self.m`` itself is left untouched, so the schedule is always
        relative to the base margin.  The new margin is printed.
        """
        m = self.m + (self.alpha * (epoch + 1))
        self._set_margin(m)
        print(f"margin updated to : {m}")

    def forward(self, input, label):
        """Return the scaled margin logits.

        Args:
            input: features multiplied with the normalised class weights.
            label: class indices; reshaped to a column and cast to ``long``.

        Returns:
            Tensor with the same shape as the cosine matrix, scaled by ``s``.
        """
        # ---------------------- cos(theta) & phi(theta) ----------------------
        cosine = F.linear(input, F.normalize(self.weight)).clamp(-1, 1)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # ---------------------- convert label to one-hot ----------------------
        # Allocate next to the logits rather than on a device fixed at
        # construction time, so the module works on CPU and after `.to(...)`.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features

        # Blend: phi where one_hot is 1, cosine where it is 0 (soft mix when
        # label smoothing is enabled).
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output