Anirban0011's picture
upd
fcd2005
from main_folder.code_base.utils import CFG
import faiss
import torch
def build_faiss(feas, dim):
if CFG.device.type == "cpu":
index = faiss.IndexFlatIP(dim)
else :
res = faiss.StandardGpuResources()
index = faiss.GpuIndexFlatIP(res, dim)
index.add(feas)
return index
def get_batches(bs, n_batch, feas):
batches = []
for i in range(n_batch):
left = bs * i
right = bs * (i+1)
if i == n_batch - 1:
right = feas.shape[0]
batches.append(feas[left:right,:])
return batches
def get_matches(bs, n_batch, feas, dim, k=51):
index = build_faiss(feas, dim)
m=[]
s=[]
for batch in get_batches(bs, n_batch, feas):
batch = batch.to(CFG.device)
sims, matches = index.search(batch, k)
m.append(matches)
s.append(sims)
m = torch.cat(m, dim=0).to(torch.int32)
s = torch.cat(s, dim=0)
return m,s
def th_matches(bs, n_batch, matches, sims, th):
matches = get_batches(bs, n_batch, matches)
sims = get_batches(bs, n_batch, sims)
m = []
s=[]
for (batch_m, batch_s) in zip(matches, sims):
batch_m = batch_m.cpu().numpy()
batch_s = batch_s.cpu().numpy()
mask = (batch_s > th)
for row in range(len(mask)):
m.append(batch_m[row][mask[row]].tolist())
s.append(batch_s[row][mask[row]].tolist())
return m, s