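# Gradio demo: evaluate custom query prompts (prefix/postfix) with a RepLLaMA-style
# retriever over precomputed BEIR corpus embeddings, scoring the runs with trec_eval.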
import gradio as gr
import pickle
import numpy as np
import glob
from tqdm import tqdm
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
from tevatron.retriever.searcher import FaissFlatSearcher
import logging
import os
import json
import spaces
import ir_datasets
import subprocess

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
CUR_MODEL = "orionweller/repllama-instruct-hard-positives-v2-joint"
base_model = "meta-llama/Llama-2-7b-hf"
tokenizer = None
model = None
retriever = None
corpus_lookup = None
queries = None
q_lookup = None
loaded_dataset = None  # name of the dataset whose embeddings/queries are currently cached

def load_model():
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    # Load the base model in fp16 so the merged 7B model fits on a single GPU,
    # then merge the retrieval LoRA adapter into it.
    base_model_instance = AutoModel.from_pretrained(base_model, torch_dtype=torch.float16)
    model = PeftModel.from_pretrained(base_model_instance, CUR_MODEL)
    model = model.merge_and_unload()
    model.eval()
    model.cuda()

def load_corpus_embeddings(dataset_name):
    global retriever, corpus_lookup
    corpus_path = f"{dataset_name}/corpus_emb*"
    index_files = glob.glob(corpus_path)
    logger.info(f'Pattern match found {len(index_files)} files; loading them into index.')

    p_reps_0, p_lookup_0 = pickle_load(index_files[0])
    retriever = FaissFlatSearcher(p_reps_0)

    shards = [(p_reps_0, p_lookup_0)] + [pickle_load(f) for f in index_files[1:]]
    corpus_lookup = []
    for p_reps, p_lookup in tqdm(shards, desc='Loading shards into index', total=len(index_files)):
        retriever.add(p_reps)
        corpus_lookup += p_lookup

def pickle_load(path):
    with open(path, 'rb') as f:
        reps, lookup = pickle.load(f)
    return np.array(reps), lookup

def load_queries(dataset_name):
    global queries, q_lookup
    dataset = ir_datasets.load(f"beir/{dataset_name.lower()}/test")
    queries = []
    q_lookup = {}
    for query in dataset.queries_iter():
        queries.append(query.text)
        q_lookup[query.query_id] = query.text

def encode_queries(prefix, postfix):
    global queries
    input_texts = [f"{prefix}Query: {query} {postfix}".strip() for query in queries]

    encoded_embeds = []
    batch_size = 32  # Adjust as needed
    for start_idx in range(0, len(input_texts), batch_size):
        batch_input_texts = input_texts[start_idx: start_idx + batch_size]
        inputs = tokenizer(batch_input_texts, padding=True, truncation=True, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Pool the hidden state of the last non-padding token (RepLLaMA-style pooling).
        # LLaMA is decoder-only and has no [CLS] token, so taking position 0 would return
        # the same BOS representation for every query.
        last_hidden = outputs.last_hidden_state
        last_token_idx = inputs["attention_mask"].sum(dim=1) - 1
        embeds = last_hidden[torch.arange(last_hidden.size(0), device=last_hidden.device), last_token_idx]
        embeds = F.normalize(embeds, p=2, dim=-1)
        # Faiss expects float32 query vectors.
        encoded_embeds.append(embeds.float().cpu().numpy())
    return np.concatenate(encoded_embeds, axis=0)

def search_queries(q_reps, depth=1000):
    all_scores, all_indices = retriever.search(q_reps, depth)
    psg_indices = [[str(corpus_lookup[x]) for x in q_dd] for q_dd in all_indices]
    return all_scores, np.array(psg_indices)

def write_ranking(corpus_indices, corpus_scores, ranking_save_file):
    # Write results in Tevatron's ranking format (qid <tab> docid <tab> score), which is the
    # input that tevatron.utils.format.convert_result_to_trec expects; the TREC run file
    # itself is produced during evaluation.
    with open(ranking_save_file, 'w') as f:
        for qid, q_doc_scores, q_doc_indices in zip(q_lookup.keys(), corpus_scores, corpus_indices):
            score_list = sorted(zip(q_doc_scores, q_doc_indices), key=lambda x: x[0], reverse=True)
            for s, idx in score_list:
                f.write(f'{qid}\t{idx}\t{s}\n')

def evaluate_with_subprocess(dataset, ranking_file):
    # Convert to TREC format (--remove_query drops hits whose doc id matches the query id)
    trec_file = f"rank.{dataset}.trec"
    convert_cmd = [
        "python", "-m", "tevatron.utils.format.convert_result_to_trec",
        "--input", ranking_file,
        "--output", trec_file,
        "--remove_query"
    ]
    subprocess.run(convert_cmd, check=True)

    # Evaluate using trec_eval
    eval_cmd = [
        "python", "-m", "pyserini.eval.trec_eval",
        "-c", "-mrecall.100", "-mndcg_cut.10",
        f"beir-v1.0.0-{dataset}-test", trec_file
    ]
    result = subprocess.run(eval_cmd, capture_output=True, text=True, check=True)

    # Parse the output by metric name rather than by line position
    ndcg_10 = recall_100 = float("nan")
    for line in result.stdout.strip().split('\n'):
        parts = line.split()
        if len(parts) >= 3 and parts[0] == "ndcg_cut_10":
            ndcg_10 = float(parts[-1])
        elif len(parts) >= 3 and parts[0] == "recall_100":
            recall_100 = float(parts[-1])

    # Clean up temporary files
    os.remove(ranking_file)
    os.remove(trec_file)

    return f"nDCG@10: {ndcg_10:.4f}, Recall@100: {recall_100:.4f}"

# Request a GPU for the duration of each call when running on ZeroGPU hardware
@spaces.GPU
def run_evaluation(dataset, prefix, postfix):
    global loaded_dataset
    # Load corpus embeddings and queries, reloading whenever the selected dataset changes
    if loaded_dataset != dataset:
        load_corpus_embeddings(dataset)
        load_queries(dataset)
        loaded_dataset = dataset

    # Encode queries
    q_reps = encode_queries(prefix, postfix)

    # Search
    all_scores, psg_indices = search_queries(q_reps)

    # Write ranking
    ranking_file = f"temp_ranking_{dataset}.txt"
    write_ranking(psg_indices, all_scores, ranking_file)

    # Evaluate
    results = evaluate_with_subprocess(dataset, ranking_file)
    return results

def gradio_interface(dataset, prefix, postfix):
    return run_evaluation(dataset, prefix, postfix)

# Load model
load_model()

# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(choices=["scifact", "arguana"], label="Dataset"),
        gr.Textbox(label="Prefix prompt"),
        gr.Textbox(label="Postfix prompt")
    ],
    outputs=gr.Textbox(label="Evaluation Results"),
    title="Query Evaluation with Custom Prompts",
    description="Select a dataset and enter prefix and postfix prompts to evaluate queries using Pyserini."
)

# Launch the interface
iface.launch()
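# Rough local-run notes (a sketch; the package list, file name, and data layout below are
# assumptions, not part of this Space's configuration):
#   pip install gradio spaces torch transformers peft tevatron faiss-cpu pyserini ir_datasets
#   python app.py
# Precomputed corpus embeddings are expected under ./{dataset}/corpus_emb* as pickled
# (reps, lookup) shards, matching the glob used in load_corpus_embeddings.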