# Author: orionweller — commit a907241 ("init"), 5.96 kB.
import gradio as gr
import pickle
import numpy as np
import glob
from tqdm import tqdm
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
from tevatron.retriever.searcher import FaissFlatSearcher
import logging
import os
import json
import spaces
import ir_datasets
import subprocess
# Configure root logging at INFO level at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model identifiers: the PEFT adapter repo and the base checkpoint it is applied to.
CUR_MODEL = "orionweller/repllama-instruct-hard-positives-v2-joint"
base_model = "meta-llama/Llama-2-7b-hf"

# Lazily populated global state, filled in by the loader functions below.
tokenizer = model = None          # set by load_model()
retriever = corpus_lookup = None  # set by load_corpus_embeddings()
queries = q_lookup = None         # set by load_queries()
def load_model():
    """Load the tokenizer and the adapter-merged encoder into module globals.

    The tokenizer and the base weights both come from ``base_model``. The
    PEFT adapter ``CUR_MODEL`` is applied on top and merged into the base
    weights. The resulting model is put in eval mode and moved to the GPU.

    Side effects:
        Sets the module-level ``tokenizer`` and ``model``.
    """
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    # Reuse EOS as the padding token; padded positions are excluded through
    # the attention mask.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    # Load weights from the same id as the tokenizer so the two cannot drift
    # apart if `base_model` is changed.
    base_model_instance = AutoModel.from_pretrained(base_model)
    model = PeftModel.from_pretrained(base_model_instance, CUR_MODEL)
    # Fold the adapter weights into the base model for plain inference.
    model = model.merge_and_unload()
    model.eval()
    model.cuda()
def load_corpus_embeddings(dataset_name):
    """Build the search index and doc-id lookup from ``{dataset_name}/corpus_emb*``.

    Each matching file is a pickled ``(reps, lookup)`` shard read with
    ``pickle_load``. Shards are added to a ``FaissFlatSearcher`` in sorted
    filename order, and their lookups are concatenated in the same order.
    Positions in ``corpus_lookup`` therefore line up with the order in which
    rows were added to the index.

    Raises:
        FileNotFoundError: if no file matches the shard pattern.

    Side effects:
        Sets the module-level ``retriever`` and ``corpus_lookup``. They are
        only set after every shard has loaded successfully.
    """
    global retriever, corpus_lookup
    from itertools import chain

    corpus_path = f"{dataset_name}/corpus_emb*"
    # Sort so the shard (and therefore row) order is reproducible across runs.
    index_files = sorted(glob.glob(corpus_path))
    if not index_files:
        raise FileNotFoundError(f"No corpus embedding shards match {corpus_path!r}")
    logger.info(f'Pattern match found {len(index_files)} files; loading them into index.')
    p_reps_0, p_lookup_0 = pickle_load(index_files[0])
    # The first shard is passed to the constructor (presumably only to size
    # the index) and is still added explicitly in the loop below.
    searcher = FaissFlatSearcher(p_reps_0)
    # Read the remaining shards lazily, one per iteration, rather than
    # holding all of them in memory before indexing starts.
    shards = chain([(p_reps_0, p_lookup_0)], map(pickle_load, index_files[1:]))
    lookup = []
    for p_reps, p_lookup in tqdm(shards, desc='Loading shards into index', total=len(index_files)):
        searcher.add(p_reps)
        lookup += p_lookup
    # Publish both globals together so a failure cannot leave a half-built
    # index cached for later requests.
    retriever, corpus_lookup = searcher, lookup
def pickle_load(path):
    """Read one pickled ``(reps, lookup)`` pair from *path*.

    Returns ``reps`` converted to a NumPy array, and ``lookup`` unchanged.

    NOTE(review): unpickling can execute arbitrary code; only point this at
    trusted shard files.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    reps, lookup = payload
    return np.array(reps), lookup
def load_queries(dataset_name):
    """Load the BEIR test queries for *dataset_name* into module globals.

    First tries the ir_datasets id ``beir/<name>/test``. Datasets that ship
    only a test split (e.g. ArguAna) are registered without a split suffix,
    as ``beir/<name>``, so that id is used as a fallback.

    Side effects:
        Sets ``queries`` (query texts in iteration order) and ``q_lookup``
        (query_id -> text, inserted in the same order). Both are set
        together, only after iteration completes.
    """
    global queries, q_lookup
    name = dataset_name.lower()
    try:
        dataset = ir_datasets.load(f"beir/{name}/test")
    except KeyError:
        # NOTE(review): ir_datasets raises KeyError for unregistered ids;
        # re-check if the library version changes.
        dataset = ir_datasets.load(f"beir/{name}")
    new_queries = []
    new_lookup = {}
    for query in dataset.queries_iter():
        new_queries.append(query.text)
        new_lookup[query.query_id] = query.text
    queries, q_lookup = new_queries, new_lookup
def _last_token_pool(hidden_states, attention_mask):
    """Return each sequence's hidden state at its last non-padding token.

    With left padding, every sequence ends at the final position. With right
    padding, the last real token sits at ``attention_mask.sum(1) - 1``.
    """
    if bool((attention_mask[:, -1] == 1).all()):
        return hidden_states[:, -1]
    last_idx = attention_mask.sum(dim=1) - 1
    rows = torch.arange(hidden_states.shape[0], device=hidden_states.device)
    return hidden_states[rows, last_idx]


def encode_queries(prefix, postfix, batch_size=32, max_length=512):
    """Embed every loaded query, wrapped in the user's prompts, with ``model``.

    Each query becomes ``f"{prefix}Query: {query} {postfix}".strip()``. An EOS
    token is appended after tokenization. The hidden state of that last
    (non-padding) token is the query embedding, which is then L2-normalised.

    Args:
        prefix: Text placed before ``"Query: "``.
        postfix: Text placed after the query text.
        batch_size: Number of queries per forward pass.
        max_length: Maximum tokens per query, including the appended EOS.

    Returns:
        float32 array of shape ``(len(queries), hidden_size)``.
    """
    input_texts = [f"{prefix}Query: {query} {postfix}".strip() for query in queries]
    encoded_embeds = []
    for start_idx in range(0, len(input_texts), batch_size):
        batch_texts = input_texts[start_idx:start_idx + batch_size]
        # Tokenize unpadded so EOS can be appended to every sequence (one
        # slot is reserved for it), then pad the batch afterwards.
        batch = tokenizer(
            batch_texts,
            max_length=max_length - 1,
            truncation=True,
            padding=False,
            return_attention_mask=False,
        )
        batch["input_ids"] = [ids + [tokenizer.eos_token_id] for ids in batch["input_ids"]]
        inputs = tokenizer.pad(
            batch, padding=True, return_attention_mask=True, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Llama is causal: position 0 (BOS) attends only to itself, so its
        # hidden state is identical for every query. Pool the final real
        # token instead, which has attended to the whole prompt.
        # NOTE(review): assumes the corpus_emb* shards were encoded the same
        # way (EOS appended, last-token pooling); confirm.
        embeds = _last_token_pool(outputs.last_hidden_state, inputs["attention_mask"])
        embeds = F.normalize(embeds, p=2, dim=-1)
        encoded_embeds.append(embeds.float().cpu().numpy())
    if not encoded_embeds:
        # np.concatenate rejects an empty list; return a well-shaped empty array.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return np.concatenate(encoded_embeds, axis=0)
def search_queries(q_reps, depth=1000):
    """Retrieve the top ``depth`` documents for each query embedding.

    Args:
        q_reps: Query embeddings, one row per query, as passed to
            ``retriever.search``.
        depth: Number of hits per query. It is capped at the corpus size.

    Returns:
        ``(scores, doc_ids)``. ``scores`` is whatever ``retriever.search``
        returns. ``doc_ids`` is a NumPy array of document-id strings from
        ``corpus_lookup``, aligned with ``scores``.
    """
    # FAISS pads results with index -1 when fewer than `depth` vectors exist,
    # and corpus_lookup[-1] would silently map those to the last document.
    depth = min(depth, len(corpus_lookup))
    all_scores, all_indices = retriever.search(q_reps, depth)
    psg_indices = [[str(corpus_lookup[x]) for x in q_dd] for q_dd in all_indices]
    return all_scores, np.array(psg_indices)
def write_ranking(corpus_indices, corpus_scores, ranking_save_file, qids=None):
    """Write retrieval results as ``qid<TAB>docid<TAB>score`` lines.

    Within each query, hits are ordered by descending score. This is the
    3-column run format consumed by tevatron's ``convert_result_to_trec``,
    which assigns ranks and writes the TREC format itself. A 6-column TREC
    line cannot be unpacked into its three fields.

    Args:
        corpus_indices: Per-query sequences of document ids.
        corpus_scores: Per-query sequences of scores, aligned with
            ``corpus_indices``.
        ranking_save_file: Output path; it is overwritten.
        qids: Query ids aligned with the rows above. Defaults to the keys of
            the module-level ``q_lookup``, which follow the order of
            ``queries``.
    """
    if qids is None:
        qids = q_lookup.keys()
    with open(ranking_save_file, 'w') as f:
        for qid, q_doc_scores, q_doc_indices in zip(qids, corpus_scores, corpus_indices):
            score_list = sorted(
                zip(q_doc_scores, q_doc_indices), key=lambda x: x[0], reverse=True
            )
            for s, idx in score_list:
                f.write(f'{qid}\t{idx}\t{s}\n')
def _parse_trec_eval_metrics(stdout):
    """Return ``{measure: value}`` for the aggregate (``all``) rows of trec_eval output.

    Any other lines, such as wrapper log lines, are ignored.
    """
    metrics = {}
    for line in stdout.splitlines():
        parts = line.split()
        if len(parts) == 3 and parts[1] == "all":
            try:
                metrics[parts[0]] = float(parts[2])
            except ValueError:
                continue
    return metrics


def evaluate_with_subprocess(dataset, ranking_file):
    """Convert a run to TREC format, score it with trec_eval, and clean up.

    Args:
        dataset: BEIR dataset name. It names the TREC file and selects
            pyserini's ``beir-v1.0.0-{dataset}-test`` qrels.
        ranking_file: Run file to convert. It is deleted afterwards, even on
            failure.

    Returns:
        ``"nDCG@10: x.xxxx, Recall@100: y.yyyy"``.

    Raises:
        subprocess.CalledProcessError: if either command exits non-zero.
        ValueError: if a requested metric is missing from trec_eval's output.
    """
    import sys
    from contextlib import suppress

    trec_file = f"rank.{dataset}.trec"
    try:
        # Convert to TREC format. --remove_query is a tevatron option that
        # presumably drops hits whose docid equals the query id; confirm.
        convert_cmd = [
            sys.executable, "-m", "tevatron.utils.format.convert_result_to_trec",
            "--input", ranking_file,
            "--output", trec_file,
            "--remove_query",
        ]
        subprocess.run(convert_cmd, check=True)
        # Evaluate using trec_eval via pyserini. sys.executable keeps the
        # subprocess on this interpreter and environment.
        eval_cmd = [
            sys.executable, "-m", "pyserini.eval.trec_eval",
            "-c", "-mrecall.100", "-mndcg_cut.10",
            f"beir-v1.0.0-{dataset}-test", trec_file,
        ]
        result = subprocess.run(eval_cmd, capture_output=True, text=True, check=True)
    finally:
        # Remove temporary files whether or not evaluation succeeded.
        for path in (ranking_file, trec_file):
            with suppress(FileNotFoundError):
                os.remove(path)

    # Look metrics up by name. trec_eval chooses its own output order, and
    # the wrapper may print extra lines, so line positions are unreliable.
    metrics = _parse_trec_eval_metrics(result.stdout)
    try:
        ndcg_10 = metrics["ndcg_cut_10"]
        recall_100 = metrics["recall_100"]
    except KeyError as err:
        raise ValueError(
            f"Metric {err} missing from trec_eval output:\n{result.stdout}"
        ) from err
    return f"nDCG@10: {ndcg_10:.4f}, Recall@100: {recall_100:.4f}"
# Dataset whose corpus index and queries currently populate the module
# globals; None until the first evaluation has loaded one.
_loaded_dataset = None


@spaces.GPU
def run_evaluation(dataset, prefix, postfix):
    """Encode the dataset's queries with the given prompts, retrieve, and score.

    Args:
        dataset: Dataset name chosen in the UI (e.g. ``"scifact"``).
        prefix: Prompt text placed before each query.
        postfix: Prompt text placed after each query.

    Returns:
        The metrics summary string from ``evaluate_with_subprocess``.
    """
    global _loaded_dataset
    # (Re)load when nothing is loaded yet or a different dataset was chosen.
    # Otherwise the previous dataset's queries and corpus would be scored
    # against this dataset's qrels.
    if retriever is None or queries is None or _loaded_dataset != dataset:
        load_corpus_embeddings(dataset)
        load_queries(dataset)
        _loaded_dataset = dataset
    q_reps = encode_queries(prefix, postfix)
    all_scores, psg_indices = search_queries(q_reps)
    ranking_file = f"temp_ranking_{dataset}.txt"
    write_ranking(psg_indices, all_scores, ranking_file)
    return evaluate_with_subprocess(dataset, ranking_file)
def gradio_interface(dataset, prefix, postfix):
    """Gradio callback: pass the form fields straight to ``run_evaluation``."""
    summary = run_evaluation(dataset, prefix, postfix)
    return summary
# Load the model once at startup so every request reuses it.
load_model()

# Build the Gradio UI. The dataset dropdown gets an explicit default, so a
# submit without a selection cannot pass None downstream (which would look
# for "None/corpus_emb*" shards).
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(choices=["scifact", "arguana"], value="scifact", label="Dataset"),
        gr.Textbox(label="Prefix prompt"),
        gr.Textbox(label="Postfix prompt"),
    ],
    outputs=gr.Textbox(label="Evaluation Results"),
    title="Query Evaluation with Custom Prompts",
    description="Select a dataset and enter prefix and postfix prompts to evaluate queries using Pyserini."
)

# Launch the interface (blocks while serving).
iface.launch()