| import gc | |
| import torch | |
| from utils.predict import predict | |
| from utils.filterfunc import filter_match_titles | |
| from utils.ckpts import img_ckpt, txt_ckpt | |
| from utils.utilfuncs import gen_data, load_model, return_feas | |
# Backbone identifiers for the image branch: inference() passes each entry to
# load_model(backbone=..., img=True) together with the img_ckpt entry at the
# same index.
img_backbone = [
    "timm/eca_nfnet_l1.ra2_in1k",
]

# Backbone identifiers for the text branch. inference() uses the length of this
# list to decide how many text models to load; the backbone value it hands to
# load_model comes from its TKN_PATH argument at the same index.
txt_backbone = [
    "google-bert/bert-base-uncased",
]
def clean():
    """Trigger a garbage-collection pass via ``gc.collect()``.

    The value reported by the collector is discarded; nothing is returned.
    """
    _ = gc.collect()
    return None
def inference(li, lt, IMG_SIZE,
              TKN_PATH,
              BATCH_SIZE,
              num_workers=4,
              ):
    """Extract image and text features, predict matches, and compare the two results.

    Steps, in order:
      1. Build the image and text dataloaders with ``gen_data``.
      2. Load one image model per entry of ``img_backbone`` and concatenate
         their features along ``dim=1``.
      3. Load one text model per entry of ``txt_backbone`` and concatenate
         their features along ``dim=1``.
      4. Run ``predict`` on both feature tensors, then post-process the result
         with ``filter_match_titles``.
      5. Check that the post-processed result has exactly two entries and
         report whether those two entries contain the same set of elements.

    Args:
        li: Forwarded to ``gen_data`` as its first argument
            (presumably the image inputs -- TODO confirm against gen_data).
        lt: Forwarded to ``gen_data`` as its second argument and to
            ``filter_match_titles`` as ``title_list``.
        IMG_SIZE: Forwarded to ``gen_data``.
        TKN_PATH: Indexable collection. ``TKN_PATH[0]`` is forwarded to
            ``gen_data``; ``TKN_PATH[i]`` is used as the ``backbone`` of the
            i-th text model.
        BATCH_SIZE: Forwarded to ``gen_data``.
        num_workers: Forwarded to ``gen_data``. Defaults to 4.

    Returns:
        bool: ``True`` when ``set(match_final[0]) == set(match_final[1])``
        for the filtered prediction result, ``False`` otherwise.

    Raises:
        AssertionError: If the filtered prediction result does not contain
            exactly two entries.
    """
    dataloader_img, dataloader_txt = gen_data(
        li,
        lt,
        IMG_SIZE,
        BATCH_SIZE,
        TKN_PATH[0],
        num_workers,
    )

    # Image branch: one model per configured backbone, paired with the
    # checkpoint stored at the same index.
    img_models = [
        load_model(backbone=backbone, ckpt_path=img_ckpt[i], img=True)
        for i, backbone in enumerate(img_backbone)
    ]
    img_feas = torch.cat(
        [return_feas(model, dataloader_img, img=True) for model in img_models],
        dim=1,
    )

    # Text branch: the number of models follows txt_backbone, but the backbone
    # value itself is taken from TKN_PATH at the same index.
    # NOTE(review): txt_backbone's entries are never read here -- confirm that
    # loading from TKN_PATH rather than txt_backbone is intended.
    txt_models = [
        load_model(backbone=TKN_PATH[i], ckpt_path=txt_ckpt[i])
        for i in range(len(txt_backbone))
    ]
    txt_feas = torch.cat(
        [return_feas(model, dataloader_txt) for model in txt_models],
        dim=1,
    )

    match_final = predict(img_feas=img_feas, txt_feas=txt_feas)
    match_final = filter_match_titles(match_final, title_list=lt)

    # The length must be taken of match_final itself; the previous form
    # `len(match_final == 2)` measured the result of the comparison instead,
    # so the two-entry check was never actually performed.
    assert len(match_final) == 2, (
        f"expected exactly 2 match entries, got {len(match_final)}"
    )
    return set(match_final[0]) == set(match_final[1])