| import torch | |
| from functools import partial | |
| from colbert.ranking.index_part import IndexPart | |
| from colbert.ranking.faiss_index import FaissIndex | |
| from colbert.utils.utils import flatten, zipstar | |
class Ranker:
    """Score and order candidate passages for one encoded query.

    Candidates either come from the caller (``pids``) or, when the ranker
    was built with a FAISS depth, from ``FaissIndex.retrieve``. Scoring is
    delegated to ``IndexPart.rank``.
    """

    def __init__(self, args, inference, faiss_depth=1024):
        """Load the index part and, optionally, the FAISS retriever.

        Args:
            args: Object exposing ``index_path``, ``part_range`` and, when
                ``faiss_depth`` is not ``None``, ``faiss_index_path`` and
                ``nprobe``.
            inference: Query encoder; must expose ``queryFromText`` and
                ``colbert.dim``.
            faiss_depth: First positional argument bound into
                ``FaissIndex.retrieve``. Pass ``None`` to skip loading FAISS
                entirely; ``rank`` then requires explicit ``pids``.
        """
        self.inference = inference
        self.faiss_depth = faiss_depth

        # Define both attributes unconditionally so a ranker built without
        # FAISS fails with a clear message in ``rank`` instead of an
        # AttributeError on a missing ``retrieve``.
        self.faiss_index = None
        self.retrieve = None

        if faiss_depth is not None:
            self.faiss_index = FaissIndex(
                args.index_path,
                args.faiss_index_path,
                args.nprobe,
                part_range=args.part_range,
            )
            self.retrieve = partial(self.faiss_index.retrieve, self.faiss_depth)

        self.index = IndexPart(
            args.index_path,
            dim=inference.colbert.dim,
            part_range=args.part_range,
            verbose=True,
        )

    def encode(self, queries):
        """Encode a list or tuple of queries with the inference object.

        Args:
            queries: ``list`` or ``tuple`` of queries, passed through to
                ``inference.queryFromText``.

        Returns:
            Whatever ``inference.queryFromText`` returns.

        Raises:
            AssertionError: If ``queries`` is not a list or tuple.
        """
        assert isinstance(queries, (list, tuple)), type(queries)

        # Only request explicit batching for large inputs; otherwise let the
        # encoder use its default (``bsize=None``).
        bsize = 512 if len(queries) > 512 else None
        return self.inference.queryFromText(queries, bsize=bsize)

    def rank(self, Q, pids=None):
        """Score candidates for a single query, best score first.

        Args:
            Q: 3-D tensor whose first dimension is 1 (a single query). Its
                last two axes are swapped before scoring.
                NOTE(review): presumably (1, query_len, dim) -- confirm.
            pids: Optional ``list``/``tuple`` of ``int`` passage ids. When
                ``None``, candidates are taken from the FAISS retriever.

        Returns:
            ``(pids, scores)``: two lists ordered by descending score. When
            there are no candidates, ``pids`` is returned unchanged together
            with an empty score list.

        Raises:
            ValueError: If ``pids`` is ``None`` and the ranker was built with
                ``faiss_depth=None`` (no retriever available).
            AssertionError: If ``pids`` is not a list/tuple of ``int`` or
                ``Q`` does not hold exactly one query.
        """
        if pids is None:
            if self.retrieve is None:
                raise ValueError(
                    "pids must be provided when the Ranker was created with "
                    "faiss_depth=None (no FAISS retriever is loaded)"
                )
            # ``retrieve`` returns one candidate list per query; Q holds one.
            pids = self.retrieve(Q, verbose=False)[0]

        assert isinstance(pids, (list, tuple)), type(pids)
        assert Q.size(0) == 1, (len(pids), Q.size())
        # Exact ``int`` only (bools and other int-like types are rejected).
        assert all(type(pid) is int for pid in pids)

        if not pids:
            return pids, []

        scores = self.index.rank(Q.permute(0, 2, 1), pids)

        # Sort once on the scores and reuse the permutation for the pids so
        # both lists stay aligned.
        order = torch.tensor(scores).sort(descending=True)
        ranked_pids = torch.tensor(pids)[order.indices].tolist()
        return ranked_pids, order.values.tolist()