import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class ArcMarginProduct(nn.Module):
    r"""Additive angular margin ("arc margin") classification head.

    For every class the layer computes ``cos(theta)`` between ``input`` and the
    L2-normalised class weight, replaces the target-class value by
    ``cos(theta + m)`` and scales all logits by ``s``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample (number of classes).
        s: scale applied to the returned logits.
        m: base angular margin, in radians, used in ``cos(theta + m)``.
        easy_margin: if true, apply the margin only where ``cos(theta) > 0``
            and fall back to the plain cosine elsewhere.
        ls_eps: label-smoothing factor applied to the one-hot targets
            (``0`` disables smoothing).
        alpha: per-epoch margin increment used by :meth:`update_margin`.
        device: retained for backward compatibility and stored on the
            instance; :meth:`forward` allocates on the device of its own
            activations instead.
    """

    def __init__(
        self,
        in_features,
        out_features,
        s=30.0,
        m=0.50,
        easy_margin=False,
        ls_eps=0.0,
        alpha=0.0,
        device="cuda",
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.alpha = alpha
        self.ls_eps = ls_eps  # label smoothing
        self.device = device
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self._set_margin(m)

    def _set_margin(self, m):
        """Cache the trigonometric constants derived from margin ``m``."""
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Below this cosine, theta + m would exceed pi, so forward() switches
        # to the linear fallback ``cosine - mm``.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def __repr__(self):
        return (
            "{in_features}, {out_features}, s={s}, m = {m}, "
            "easy_margin = {easy_margin}, ls_eps = {ls_eps} ".format(
                in_features=self.in_features,
                out_features=self.out_features,
                s=self.s,
                m=self.m,
                easy_margin=self.easy_margin,
                ls_eps=self.ls_eps,
            )
        )

    def update_margin(self, epoch):
        """Recompute the margin constants for ``epoch``.

        The effective margin is ``self.m + self.alpha * (epoch + 1)``.
        ``self.m`` itself is left untouched, so repeated calls do not compound.
        """
        m = self.m + (self.alpha * (epoch + 1))
        self._set_margin(m)
        print(f"margin updated to : {m}")
        return None

    def forward(self, input: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Return the scaled margin logits.

        Args:
            input: embeddings; the code indexes classes on dim 1 of the
                product with ``weight``, so a 2-D ``(batch, in_features)``
                tensor is expected. Only ``weight`` is normalised here and the
                product is clamped to ``[-1, 1]`` -- presumably callers pass
                unit-norm embeddings; TODO confirm.
            label: integer class index per sample; reshaped to ``(-1, 1)``
                and cast to ``long``.

        Returns:
            Tensor shaped like ``input @ weight.T`` holding ``s * cos(theta + m)``
            at the target class and ``s * cos(theta)`` elsewhere (blended when
            label smoothing is active).
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(input, F.normalize(self.weight)).clamp(-1, 1)
        # NOTE(review): sqrt has an unbounded gradient at 0, i.e. when
        # |cosine| == 1 exactly; left unchanged to keep numerics identical.
        sine = torch.sqrt(1.0 - cosine.pow(2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        # Allocate next to the activations rather than on the constructor's
        # ``device`` string, so the layer works on CPU and after ``.to(...)``.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features

        # Margin logit where the (smoothed) target mass is, plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s