import torch import numpy as np from main_folder.code_base.utils import CFG import torch.nn.functional as F from functools import reduce from utils.knn import get_matches, th_matches K = 51 di = 1792 dt = 1024 dc = di+dt n_batch = 10 def filter_embeddings(feas, matches, sims): feas = feas.detach().cpu() new_feas = feas.clone() for i in range(feas.shape[0]): cur_feas = feas[matches[i]] weights = torch.unsqueeze(torch.Tensor(sims[i]), 1) new_feas[i] = weights.T@cur_feas new_feas = F.normalize(new_feas) return new_feas.to(CFG.device) def filter_matches(matches, sims, th=1.0, k=3, dist=1e-2): top_matches = [row[:k] for row in matches] top_sims = [row[:k] for row in sims] for i in range(len(matches)): if len(matches[i]) < k+1: continue dist_1 = sims[i][k-2] - sims[i][k-1] dist_2 = sims[i][k-1] - sims[i][k] if dist_2 < dist: continue if th*dist_1 < dist_2: matches[i] = top_matches[i] sims[i] = top_sims[i] return matches, sims def union_matches(*lists): matches = [] for group in zip(*lists): matches.append(reduce(np.union1d, group).tolist()) return matches def predict(img_feas, txt_feas): img_feas, txt_feas = F.normalize(img_feas).to(CFG.device) , F.normalize(txt_feas).to(CFG.device) comb_feas = F.normalize(torch.cat([img_feas, txt_feas], dim=1)).to(CFG.device) bs = len(comb_feas) // n_batch img_matches, img_sims = get_matches(bs, n_batch, img_feas, di, k=K) text_matches, text_sims = get_matches(bs, n_batch, txt_feas, dt, k=K) comb_matches, comb_sims = get_matches(bs, n_batch, comb_feas, dc, k=K) img_final, img_sims = th_matches(bs, n_batch, img_matches, img_sims, 0.704) text_final, text_sims = th_matches(bs, n_batch, text_matches, text_sims, 0.764) comb_final, comb_sims = th_matches(bs, n_batch, comb_matches, comb_sims, 0.52) comb_feas = filter_embeddings(comb_feas, comb_final, comb_sims) comb_matches, comb_sims = get_matches(bs, n_batch, comb_feas, dc, k=K) comb_final, comb_sims = th_matches(bs, n_batch, comb_matches, comb_sims, 0.9) img_final,_ = filter_matches(img_final, img_sims, 1.1, 4, 2e-2) text_final,_ = filter_matches(text_final, text_sims, 1.2, 4, 2e-2) comb_final,_ = filter_matches(comb_final, comb_sims, 1.0, 3, 2e-2) match_final = union_matches(img_final, text_final, comb_final) return match_final