|
|
import torch |
|
|
import math |
|
|
import torch.nn as nn |
|
|
from torch.nn import Parameter |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class ArcMarginProduct(nn.Module):
    r"""Additive angular margin (ArcFace-style) classification head.

    Before label smoothing, the logit for the target class ``y`` is
    ``s * cos(theta_y + m)`` and the logit for every other class ``j`` is
    ``s * cos(theta_j)``. Here ``theta`` is the angle between the
    L2-normalised input embedding and the L2-normalised class weight vector.

    Args:
        in_features: size of each input sample (embedding dimension).
        out_features: size of each output sample (number of classes).
        s: scale applied to the cosine logits (the norm the normalised
            features are effectively rescaled to).
        m: base additive angular margin, in radians.
        easy_margin: if True, the margin is applied only where
            ``cos(theta) > 0``. Otherwise it is applied where
            ``cos(theta) > cos(pi - m)`` and replaced by
            ``cos(theta) - sin(pi - m) * m`` elsewhere.
        ls_eps: label-smoothing factor mixed into the one-hot target mask.
        alpha: per-epoch margin increment used by :meth:`update_margin`.
        device: kept for backward compatibility. Internal buffers are now
            allocated on the same device as the computed logits.
    """

    # Floor for ``1 - cos^2`` before the square root. d/dx sqrt(x) is infinite
    # at x == 0, which turns gradients into NaN whenever cos(theta) hits +-1.
    _SINE_EPS = 1e-7

    def __init__(
        self,
        in_features,
        out_features,
        s=30.0,
        m=0.50,
        easy_margin=False,
        ls_eps=0.0,
        alpha=0.0,
        device="cuda",
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        # Base margin. update_margin() always offsets from this value.
        self.m = m
        self.alpha = alpha
        self.ls_eps = ls_eps
        self.device = device
        self.weight = Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self._set_margin_constants(m)

    def _set_margin_constants(self, m):
        """Cache the trigonometric constants that ``forward`` uses for margin ``m``."""
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold below which theta + m would exceed pi.
        self.th = math.cos(math.pi - m)
        # Linear penalty used instead of cos(theta + m) below that threshold.
        self.mm = math.sin(math.pi - m) * m

    def __repr__(self):
        return (
            "{in_features}, {out_features}, s={s}, m = {m}, "
            "easy_margin = {easy_margin}, ls_eps = {ls_eps} ".format(**self.__dict__)
        )

    def update_margin(self, epoch):
        """Recompute the cached margin constants for the given epoch index.

        The effective margin is ``self.m + self.alpha * (epoch + 1)``. It is
        always derived from the base margin ``self.m``, so repeated calls do
        not accumulate, and ``self.m`` itself is left unchanged.
        """
        m = self.m + (self.alpha * (epoch + 1))
        self._set_margin_constants(m)
        print(f"margin updated to : {m}")
        return None

    def forward(self, input, label):
        """Compute scaled margin logits.

        Args:
            input: embedding batch. Presumably ``(batch, in_features)``. It is
                L2-normalised along its last dimension here, so callers that
                already pass unit vectors see no change.
            label: integer class index per row of ``input``.

        Returns:
            Logits of shape ``(batch, out_features)``, scaled by ``self.s``.
        """
        # Both sides must be unit vectors for the dot product to be a cosine.
        cosine = F.linear(F.normalize(input, dim=-1), F.normalize(self.weight))
        cosine = cosine.clamp(-1, 1)
        # The floor keeps the sqrt gradient finite when |cosine| == 1.
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=self._SINE_EPS))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Allocate on the logits' device rather than the constructor-time
        # `device` string, so the module works on CPU or any GPU.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        index = label.reshape(-1, 1).long().to(cosine.device)
        one_hot.scatter_(1, index, 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s
|
|
|