# Spaces: Runtime error (page status text apparently captured along with this listing)
import sys
import warnings

print("Warning: This application requires specific library versions. Please ensure you have the correct versions installed.")

import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
# Report the library versions in use so a version mismatch shows up in the logs.
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Suppress CUDA initialization warning
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")

# Check for GPU availability; the model is placed on this device when it is loaded.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Model loading and setup
model_name = "jhu-clsp/FollowIR-7B"

try:
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
except ValueError as e:
    # Exit with a readable hint instead of a bare traceback.
    print(f"Error loading model or tokenizer: {e}")
    print("Please ensure you have the correct versions of transformers and sentencepiece installed.")
    sys.exit(1)

# Reuse the end-of-sequence token for padding, and pad on the left so that the
# final position of every sequence is its last real token (check_relevance
# reads the logits at index -1).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Vocabulary ids of the two answer words the model is asked to choose between.
# The vocabulary is fetched once rather than once per lookup.
_vocab = tokenizer.get_vocab()
token_false_id = _vocab["false"]
token_true_id = _vocab["true"]
# Prompt sent to the model; `{query}` and `{text}` are filled in per request.
# The trailing space after "[/INST]" is part of the prompt.
# NOTE(review): any blank lines that the prompt originally contained cannot be
# recovered from this copy -- compare against the model card before relying on it.
template = """<s> [INST] You are an expert Google searcher, whose job is to determine if the following document is relevant to the query (true/false). Answer using only one word, one of those two choices.
Query: {query}
Document: {text}
Relevant (only output one word, either "true" or "false"): [/INST] """
def check_relevance(query, instruction, passage):
    """Return the model's probability that ``passage`` is relevant.

    The query and instruction are joined with a single space, inserted into
    the module-level ``template`` together with the passage, and run through
    the model once. The logits of the final position for the "true" and
    "false" tokens are normalised against each other, and the share assigned
    to "true" is returned.

    Args:
        query: The search query text.
        instruction: Extra relevance criteria appended to the query.
        passage: The document text to judge.

    Returns:
        The probability of "true" formatted as a string with four decimals.

    NOTE(review): ``spaces`` is imported at the top of the file but nothing in
    the visible code uses it. If this is meant to run on ZeroGPU hardware the
    handler presumably needs the ``@spaces.GPU`` decorator -- confirm.
    """
    # Always pick a device. Previously the name was bound only when CUDA was
    # available, so a CPU-only run failed with UnboundLocalError below.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    # nn.Module.to moves the parameters in place, so no rebinding is needed.
    model.to(run_device)

    full_query = f"{query} {instruction}"
    prompt = template.format(query=full_query, text=passage)

    # NOTE(review): with truncation enabled, a very long passage could lose the
    # closing "[/INST]" if the tokenizer truncates on the right -- confirm.
    encoded = tokenizer(
        [prompt],
        padding=True,
        truncation=True,
        return_tensors="pt",
        pad_to_multiple_of=None,
    )
    inputs = {key: value.to(run_device) for key, value in encoded.items()}

    with torch.no_grad():
        # Logits for the next token after the prompt (last position).
        last_logits = model(**inputs).logits[:, -1, :]
        # Column 0 = "false", column 1 = "true"; normalise over just these two.
        pair = torch.stack(
            [last_logits[:, token_false_id], last_logits[:, token_true_id]],
            dim=1,
        )
        log_probs = torch.nn.functional.log_softmax(pair, dim=1)
        score = log_probs[:, 1].exp().item()

    return f"{score:.4f}"
# Example inputs: each row is [query, instruction, passage], in the order the
# interface passes them to check_relevance. The two rows differ only in whether
# the passage describes the film as co-directed or directed.
examples = [
    [
        "What movies were directed by James Cameron?",
        "A relevant document would describe any movie that was directed by James Cameron but not any that are co-directed.",
        "Avatar: The Way of Water is a 2022 American epic science fiction film co-produced and co-directed by James Cameron and Rick Jaffe.",
    ],
    [
        "What movies were directed by James Cameron?",
        "A relevant document would describe any movie that was directed by James Cameron but not any that are co-directed.",
        "Avatar: The Way of Water is a 2022 American epic science fiction film co-produced and directed by James Cameron. Rick Jaffe helped write the script.",
    ],
]
# Gradio Interface
# NOTE(review): the nesting of the layout blocks below is reconstructed from
# statement order; in particular the button is assumed to sit in the first
# column -- confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Relevance Using Instructions")
    gr.Markdown("This app uses the FollowIR-7B model to determine the relevance of a passage to a given query and instruction.")

    with gr.Row():
        # Left column: the three inputs and the submit button.
        with gr.Column():
            query_input = gr.Textbox(label="Query", placeholder="Enter your search query here")
            instruction_input = gr.Textbox(label="Instruction", placeholder="Enter additional instructions or criteria")
            passage_input = gr.Textbox(label="Passage", placeholder="Enter the passage to check for relevance", lines=5)
            submit_button = gr.Button("Check Relevance")
        # Right column: the probability string returned by check_relevance.
        with gr.Column():
            output = gr.Textbox(label="Relevance Probability")

    # Clickable example rows, wired to the same handler as the button.
    gr.Examples(
        examples=examples,
        inputs=[query_input, instruction_input, passage_input],
        outputs=output,
        fn=check_relevance,
        cache_examples=True,
    )

    submit_button.click(
        check_relevance,
        inputs=[query_input, instruction_input, passage_input],
        outputs=[output],
    )
if __name__ == "__main__":
    # Refuse to start the UI on NumPy 2.x, which this app declares unsupported.
    if np.__version__.startswith("2."):
        print("Error: This application is not compatible with NumPy 2.x. Please downgrade to NumPy < 2.0.0.")
        sys.exit(1)
    demo.launch()