| from main_folder.code_base.utils import CFG | |
| import faiss | |
| import torch | |
def build_faiss(feas, dim):
    """Create a flat inner-product FAISS index and populate it with ``feas``.

    The index type follows ``CFG.device``: a GPU ``GpuIndexFlatIP`` (backed by
    a newly created ``StandardGpuResources``) for any non-CPU device, and a
    plain ``IndexFlatIP`` when ``CFG.device.type == "cpu"``.

    Args:
        feas: Vectors to insert. Presumably an (n, dim) float tensor/array —
            TODO confirm with callers. NOTE(review): handing torch tensors
            directly to faiss needs faiss's torch interop
            (``faiss.contrib.torch_utils``) imported somewhere; not visible here.
        dim: Dimensionality of each vector.

    Returns:
        The FAISS index after ``add(feas)``.
    """
    if CFG.device.type != "cpu":
        gpu_resources = faiss.StandardGpuResources()
        flat_index = faiss.GpuIndexFlatIP(gpu_resources, dim)
    else:
        flat_index = faiss.IndexFlatIP(dim)
    flat_index.add(feas)
    return flat_index
def get_batches(bs, n_batch, feas):
    """Split ``feas`` along its first axis into ``n_batch`` consecutive slices.

    Batch ``i`` covers rows ``[bs * i, bs * (i + 1))``, except the final batch,
    which always runs to the end of ``feas``. Remainder rows are therefore
    never dropped when ``bs * n_batch`` is smaller than the row count.

    Args:
        bs: Nominal number of rows per batch.
        n_batch: Number of batches to produce; ``0`` yields an empty list.
        feas: Any sliceable object supporting ``len`` — a torch tensor or
            NumPy array of rank >= 1, or a plain list.

    Returns:
        A list of ``n_batch`` slices of ``feas`` (views for tensors/arrays).
        Batches whose start lies past the end of ``feas`` are empty.
    """
    batches = []
    for i in range(n_batch):
        left = bs * i
        # The last batch absorbs any remainder so no rows are lost. ``len`` is
        # only evaluated here, matching the original's lazy ``shape[0]`` read.
        right = len(feas) if i == n_batch - 1 else bs * (i + 1)
        # Slice only the leading axis: identical to ``[left:right, :]`` for
        # rank >= 2, but also works for 1-D tensors/arrays and lists.
        batches.append(feas[left:right])
    return batches
def get_matches(bs, n_batch, feas, dim, k=51):
    """Run a top-``k`` inner-product search of ``feas`` against itself.

    The index is built from all of ``feas`` via ``build_faiss``; the queries
    are the same rows, fed in ``n_batch`` chunks from ``get_batches`` and
    moved to ``CFG.device`` before each ``index.search`` call. Because the
    queries are the indexed vectors, each row's top hit is presumably itself
    — confirm if callers rely on that.

    Args:
        bs: Rows per query chunk (the last chunk takes any remainder).
        n_batch: Number of query chunks.
        feas: Vectors to index and query; presumably an (n, dim) torch tensor
            — TODO confirm. NOTE(review): ``index.search`` on torch tensors
            relies on faiss's torch interop being imported elsewhere.
        dim: Vector dimensionality passed to ``build_faiss``.
        k: Number of neighbours requested per query.

    Returns:
        ``(matches, sims)``: the per-chunk neighbour ids concatenated along
        dim 0 and cast to ``torch.int32``, and the per-chunk similarity
        scores concatenated along dim 0.
    """
    index = build_faiss(feas, dim)
    all_matches, all_sims = [], []
    for chunk in get_batches(bs, n_batch, feas):
        # faiss returns (scores, ids) — note the order.
        chunk_sims, chunk_matches = index.search(chunk.to(CFG.device), k)
        all_matches.append(chunk_matches)
        all_sims.append(chunk_sims)
    matches = torch.cat(all_matches, dim=0).to(torch.int32)
    sims = torch.cat(all_sims, dim=0)
    return matches, sims
def th_matches(bs, n_batch, matches, sims, th):
    """Keep, per row, only the neighbours whose similarity is strictly above ``th``.

    ``matches`` and ``sims`` are walked in the same ``n_batch`` chunks (via
    ``get_batches``). Each chunk is copied to host memory as NumPy. A boolean
    mask ``sims > th`` then selects the surviving entries of each row.

    Args:
        bs: Rows per chunk (the last chunk takes any remainder).
        n_batch: Number of chunks.
        matches: Neighbour-id tensor, row-aligned with ``sims``; presumably
            the ``matches`` output of ``get_matches`` — confirm.
        sims: Similarity tensor with the same layout as ``matches``.
        th: Threshold; an entry is kept only when its similarity ``> th``.

    Returns:
        ``(kept_matches, kept_sims)``: two lists with one Python list per
        row, holding the surviving ids and similarities in their original
        order.
    """
    match_chunks = get_batches(bs, n_batch, matches)
    sim_chunks = get_batches(bs, n_batch, sims)
    kept_matches = []
    kept_sims = []
    for match_chunk, sim_chunk in zip(match_chunks, sim_chunks):
        match_np = match_chunk.cpu().numpy()
        sim_np = sim_chunk.cpu().numpy()
        keep = sim_np > th
        # Rows are indexed by position, driven by the mask's row count.
        for r, keep_row in enumerate(keep):
            kept_matches.append(match_np[r][keep_row].tolist())
            kept_sims.append(sim_np[r][keep_row].tolist())
    return kept_matches, kept_sims