Anirban0011's picture
upd
fcd2005
import gc
import torch
from utils.predict import predict
from utils.filterfunc import filter_match_titles
from utils.ckpts import img_ckpt, txt_ckpt
from utils.utilfuncs import gen_data, load_model, return_feas
img_backbone = ["timm/eca_nfnet_l1.ra2_in1k"]
txt_backbone = ["google-bert/bert-base-uncased"]
def clean():
gc.collect()
def inference(li, lt, IMG_SIZE,
TKN_PATH,
BATCH_SIZE,
num_workers = 4,
):
dataloader_img, dataloader_txt = gen_data(li,
lt,
IMG_SIZE,
BATCH_SIZE,
TKN_PATH[0],
num_workers)
img_model = [load_model(backbone=img_backbone[i],
ckpt_path=img_ckpt[i],
img=True)
for i in range(len(img_backbone))]
img_feas = torch.cat([return_feas(
img_model[i],
dataloader_img, img=True)
for i in range(len(img_backbone))], dim=1)
txt_model = [load_model(backbone=TKN_PATH[i], ckpt_path=txt_ckpt[i])
for i in range(len(txt_backbone))]
txt_feas = torch.cat([return_feas(
txt_model[i],
dataloader_txt)
for i in range(len(txt_backbone))], dim=1)
match_final = predict(img_feas=img_feas,
txt_feas=txt_feas)
match_final = filter_match_titles(match_final, title_list=lt)
assert len(match_final == 2)
return set(match_final[0]) == set(match_final[1])