Anirban0011's picture
upd
fcd2005
import torch
import numpy as np
from main_folder.code_base.utils import CFG
import torch.nn.functional as F
from functools import reduce
from utils.knn import get_matches, th_matches
K = 51
di = 1792
dt = 1024
dc = di+dt
n_batch = 10
def filter_embeddings(feas, matches, sims):
feas = feas.detach().cpu()
new_feas = feas.clone()
for i in range(feas.shape[0]):
cur_feas = feas[matches[i]]
weights = torch.unsqueeze(torch.Tensor(sims[i]), 1)
new_feas[i] = weights.T@cur_feas
new_feas = F.normalize(new_feas)
return new_feas.to(CFG.device)
def filter_matches(matches, sims, th=1.0, k=3, dist=1e-2):
top_matches = [row[:k] for row in matches]
top_sims = [row[:k] for row in sims]
for i in range(len(matches)):
if len(matches[i]) < k+1:
continue
dist_1 = sims[i][k-2] - sims[i][k-1]
dist_2 = sims[i][k-1] - sims[i][k]
if dist_2 < dist:
continue
if th*dist_1 < dist_2:
matches[i] = top_matches[i]
sims[i] = top_sims[i]
return matches, sims
def union_matches(*lists):
matches = []
for group in zip(*lists):
matches.append(reduce(np.union1d, group).tolist())
return matches
def predict(img_feas, txt_feas):
img_feas, txt_feas = F.normalize(img_feas).to(CFG.device) , F.normalize(txt_feas).to(CFG.device)
comb_feas = F.normalize(torch.cat([img_feas, txt_feas], dim=1)).to(CFG.device)
bs = len(comb_feas) // n_batch
img_matches, img_sims = get_matches(bs, n_batch, img_feas, di, k=K)
text_matches, text_sims = get_matches(bs, n_batch, txt_feas, dt, k=K)
comb_matches, comb_sims = get_matches(bs, n_batch, comb_feas, dc, k=K)
img_final, img_sims = th_matches(bs, n_batch, img_matches, img_sims, 0.704)
text_final, text_sims = th_matches(bs, n_batch, text_matches, text_sims, 0.764)
comb_final, comb_sims = th_matches(bs, n_batch, comb_matches, comb_sims, 0.52)
comb_feas = filter_embeddings(comb_feas, comb_final, comb_sims)
comb_matches, comb_sims = get_matches(bs, n_batch, comb_feas, dc, k=K)
comb_final, comb_sims = th_matches(bs, n_batch, comb_matches, comb_sims, 0.9)
img_final,_ = filter_matches(img_final, img_sims, 1.1, 4, 2e-2)
text_final,_ = filter_matches(text_final, text_sims, 1.2, 4, 2e-2)
comb_final,_ = filter_matches(comb_final, comb_sims, 1.0, 3, 2e-2)
match_final = union_matches(img_final, text_final, comb_final)
return match_final