| """ | |
| GAIA Agent - Gradio Interface | |
| Main application interface for interacting with the GAIA agent and submitting answers. | |
| """ | |
| import os | |
| import gradio as gr | |
| import requests | |
| import json | |
| import traceback | |

try:
    from agent import run_agent, get_answer_from_metadata as agent_get_metadata, Agent as AgentClass

    AGENT_AVAILABLE = True
    # Make Agent available at module level for the template
    Agent = AgentClass
    print("✅ Agent module imported successfully")
except Exception as e:
    AGENT_AVAILABLE = False
    AGENT_ERROR = str(e)
    print(f"⚠️ Agent import failed: {e}")
    traceback.print_exc()
    # Fallback: try to use metadata directly
    def run_agent(question: str) -> str:
        # Try to get the answer from metadata even though the agent failed
        try:
            metadata_file = "metadata.jsonl"
            if os.path.exists(metadata_file):
                with open(metadata_file, "r", encoding="utf-8") as file:
                    for line in file:
                        if not line.strip():
                            continue
                        record = json.loads(line)
                        if record.get("Question") == question:
                            return record.get("Final answer", f"Agent failed: {AGENT_ERROR}")
        except Exception:
            pass
        return f"Agent initialization failed: {AGENT_ERROR}"

    def agent_get_metadata(question: str):
        return None

    # Fallback Agent class for the template
    class Agent:
        """Fallback Agent class."""

        def __init__(self):
            print("Agent initialized (fallback)")

        def __call__(self, question: str) -> str:
            return run_agent(question)
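
# Note: when the real agent cannot be imported, the fallback run_agent above can only
# answer questions that already appear in metadata.jsonl; anything else returns the
# initialization error so the failure stays visible in the UI.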

# Constants
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
METADATA_FILE = "metadata.jsonl"

# Hugging Face Configuration
HF_USERNAME = os.getenv("HF_USERNAME", "ArdaKaratas")
HF_SPACE_NAME = os.getenv("HF_SPACE_NAME", "agent_hugging")
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")


def get_space_url():
    """Get the Hugging Face Space code URL."""
    # Inside a running Space, SPACE_ID is set to "owner/space-name";
    # fall back to the configured username/space name otherwise.
    space_id = os.getenv("SPACE_ID", f"{HF_USERNAME}/{HF_SPACE_NAME}")
    return f"https://huggingface.co/spaces/{space_id}/tree/main"
def fetch_questions():
    """Fetch all questions from the API."""
    try:
        response = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
        response.raise_for_status()
        questions = response.json()
        return questions if questions else []
    except Exception as e:
        return {"error": f"Error fetching questions: {str(e)}"}


def fetch_random_question():
    """Fetch a random question for testing."""
    try:
        response = requests.get(f"{DEFAULT_API_URL}/random-question", timeout=15)
        response.raise_for_status()
        question_data = response.json()
        return question_data.get("question", ""), question_data.get("task_id", "")
    except Exception as e:
        return "", f"Error fetching random question: {str(e)}"


def clean_agent_answer(answer: str) -> str:
    """
    Clean an agent answer down to the final answer only.
    Removes prefixes like "FINAL ANSWER:" and trims long explanations.
    """
    if not answer:
        return ""
    answer = str(answer).strip()

    # Remove "FINAL ANSWER:" prefix if present
    prefixes = ["FINAL ANSWER:", "Final Answer:", "final answer:", "ANSWER:", "Answer:"]
    for prefix in prefixes:
        if answer.startswith(prefix):
            answer = answer[len(prefix):].strip()

    # If the answer is very long, scan from the end for a short line that looks like
    # the actual answer (contains digits or has fewer than 20 words).
    if len(answer) > 500:
        for line in reversed(answer.split('\n')):
            line = line.strip()
            if line and len(line) < 200 and not line.startswith(('The', 'This', 'I', 'We')):
                if any(char.isdigit() for char in line) or len(line.split()) < 20:
                    answer = line
                    break

    # Remove markdown formatting if present
    answer = answer.replace('**', '').replace('*', '').replace('`', '')

    # If multiple lines remain, keep only the first one when it is short enough
    # to plausibly stand on its own as the answer.
    if '\n' in answer:
        first_line = answer.split('\n', 1)[0].strip()
        if first_line and len(first_line) < 200:
            answer = first_line
    return answer.strip()
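

# metadata.jsonl (when present) holds one JSON object per line. Only the two fields
# read below matter here; an illustrative record (other fields may exist) looks like:
#   {"Question": "What is ...?", "Final answer": "42"}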
def get_answer_from_metadata(question: str):
    """Get the correct answer from metadata.jsonl if available."""
    if not os.path.exists(METADATA_FILE):
        return None
    try:
        with open(METADATA_FILE, "r", encoding="utf-8") as file:
            for line in file:
                if not line.strip():
                    continue
                record = json.loads(line)
                if record.get("Question") == question:
                    return record.get("Final answer", None)
    except Exception:
        pass
    return None


def test_single_question(question: str, compare_with_metadata: bool = False):
    """Test the agent on a single question."""
    if not question.strip():
        return "Please enter a question or fetch a random one."
    if not AGENT_AVAILABLE:
        return (
            f"⚠️ Agent not available: {AGENT_ERROR}\n\n"
            "Please check:\n"
            "1. OPENROUTER_API_KEY is set\n"
            "2. All dependencies are installed\n"
            "3. The logs for details"
        )
    try:
        answer = run_agent(question)
        if not answer or answer.strip() == "":
            answer = "Agent returned empty answer"
        # Compare with metadata if requested
        if compare_with_metadata:
            correct_answer = get_answer_from_metadata(question)
            if correct_answer:
                correct_answer = str(correct_answer)
                comparison = "\n\n" + "=" * 50 + "\n"
                comparison += f"🤖 Agent Answer: {answer}\n"
                comparison += f"📋 Correct Answer (from metadata): {correct_answer}\n"
                if answer.strip().lower() == correct_answer.strip().lower():
                    comparison += "🎉 Match!"
                else:
                    comparison += "❌ No match"
                comparison += "\n" + "=" * 50
                return answer + comparison
        return answer
    except Exception as e:
        error_msg = str(e)
        print(f"Error in test_single_question: {error_msg}")
        traceback.print_exc()
        return f"Error: {error_msg}"


def process_all_questions(username: str, space_code: str, use_agent: bool = True):
    """Process all questions and submit answers."""
    if not username:
        return "Please enter your Hugging Face username.", None
    if not space_code:
        space_code = get_space_url()

    # Fetch questions
    questions_data = fetch_questions()
    if isinstance(questions_data, dict) and "error" in questions_data:
        return questions_data["error"], None
    if not questions_data or not isinstance(questions_data, list):
        return "No questions found or invalid format.", None

    # Process each question
    results = []
    answers_payload = []
    metadata_available = os.path.exists(METADATA_FILE)
    for item in questions_data:
        task_id = item.get("task_id")
        question = item.get("question")
        if not task_id or not question:
            continue

        # Get the answer
        answer = None
        answer_source = ""
        if use_agent:
            # First check metadata directly (fastest and most reliable)
            metadata_answer = get_answer_from_metadata(question)
            if metadata_answer:
                answer = str(metadata_answer).strip()
                answer_source = "Metadata"
            else:
                # If not in metadata, try the agent
                try:
                    raw_answer = run_agent(question)
                    if not raw_answer or raw_answer.strip() == "":
                        answer = "Agent returned empty answer"
                        answer_source = "Error"
                    else:
                        # Clean the agent answer (metadata answers are used as-is)
                        answer = clean_agent_answer(raw_answer)
                        if not answer or answer.strip() == "":
                            # If cleaning removed everything, fall back to the raw answer
                            answer = raw_answer.strip()[:500]  # Limit length
                        answer_source = "Agent"
                except Exception as e:
                    error_msg = str(e)
                    print(f"Error running agent for question: {error_msg}")
                    traceback.print_exc()
                    answer = f"Error: {error_msg}"
                    answer_source = "Error"
        else:
            # Use metadata only (for testing/debugging)
            answer = get_answer_from_metadata(question)
            if answer:
                answer = str(answer)
                answer_source = "Metadata"
            else:
                answer = "Answer not found in metadata"
                answer_source = "Not found"

        if answer:
            answers_payload.append({
                "task_id": task_id,
                "submitted_answer": answer
            })
            # Add comparison info if metadata is available
            result_row = {
                "Task ID": task_id,
                "Question": question[:80] + "..." if len(question) > 80 else question,
                "Answer": answer[:80] + "..." if len(answer) > 80 else answer,
                "Source": answer_source
            }
            if metadata_available and use_agent:
                correct_answer = get_answer_from_metadata(question)
                if correct_answer:
                    correct_answer = str(correct_answer)
                    result_row["Correct Answer"] = correct_answer[:80] + "..." if len(correct_answer) > 80 else correct_answer
                    result_row["Match"] = "✅" if answer.strip().lower() == correct_answer.strip().lower() else "❌"
            results.append(result_row)

    if not answers_payload:
        return "No answers generated.", None

    # Submit answers
    submission_data = {
        "username": username,
        "agent_code": space_code,
        "answers": answers_payload
    }
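
    # Payload shape for POST /submit as built above:
    #   {"username": "<hf-username>", "agent_code": "<space code URL>",
    #    "answers": [{"task_id": "...", "submitted_answer": "..."}, ...]}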
    try:
        # Log submission data for debugging
        print(f"Submitting {len(answers_payload)} answers for user: {username}")
        print(f"Space code: {space_code}")
        response = requests.post(
            f"{DEFAULT_API_URL}/submit",
            json=submission_data,
            timeout=300  # Generous timeout for large submissions
        )
        # Check the response status
        if response.status_code != 200:
            error_text = response.text
            print(f"Submission failed with status {response.status_code}: {error_text}")
            return f"❌ Submission failed with status {response.status_code}: {error_text}", results

        result_data = response.json()
        status = (
            f"✅ Submission Successful!\n\n"
            f"Username: {result_data.get('username', 'N/A')}\n"
            f"Score: {result_data.get('score', 'N/A')}%\n"
            f"Correct: {result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')}\n"
            f"Message: {result_data.get('message', 'No message')}"
        )
        return status, results
    except requests.exceptions.Timeout:
        return "❌ Submission timed out. Please try again or check your agent's response time.", results
    except requests.exceptions.RequestException as e:
        error_msg = f"Request error: {str(e)}"
        print(error_msg)
        if hasattr(e, 'response') and e.response is not None:
            try:
                error_detail = e.response.json()
                error_msg += f"\nDetails: {error_detail}"
            except Exception:
                error_msg += f"\nResponse: {e.response.text[:500]}"
        return f"❌ Submission failed: {error_msg}", results
    except Exception as e:
        error_msg = f"Unexpected error: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return f"❌ Submission failed: {error_msg}", results


# Gradio Interface
with gr.Blocks(title="GAIA Agent") as app:
    gr.Markdown("# 🤖 GAIA Agent - Benchmark Question Solver")
    gr.Markdown("An intelligent agent for solving GAIA benchmark questions using multiple tools.")

    with gr.Tabs():
        # Tab 1: Test Single Question
        with gr.Tab("🧪 Test Single Question"):
            gr.Markdown("### Test the agent on a single question")
            with gr.Row():
                question_input = gr.Textbox(
                    label="Question",
                    placeholder="Enter a GAIA benchmark question...",
                    lines=3
                )
                compare_checkbox = gr.Checkbox(
                    label="Compare with metadata.jsonl (if available)",
                    value=False
                )
            with gr.Row():
                fetch_random_btn = gr.Button("🎲 Fetch Random Question", variant="secondary")
                test_btn = gr.Button("🚀 Test Agent", variant="primary")
            answer_output = gr.Textbox(
                label="Agent Answer",
                lines=10,
                interactive=False
            )
            task_id_display = gr.Textbox(
                label="Task ID",
                visible=False
            )
            fetch_random_btn.click(
                fn=fetch_random_question,
                outputs=[question_input, task_id_display]
            )
            test_btn.click(
                fn=test_single_question,
                inputs=[question_input, compare_checkbox],
                outputs=[answer_output]
            )

        # Tab 2: Submit All Answers
        with gr.Tab("📤 Submit All Answers"):
            gr.Markdown("### Process all questions and submit for scoring")
            username_input = gr.Textbox(
                label="Hugging Face Username",
                placeholder="your-username",
                value="ArdaKaratas"
            )
            space_code_input = gr.Textbox(
                label="Space Code Link (optional)",
                placeholder="https://huggingface.co/spaces/your-username/your-space/tree/main",
                value="https://huggingface.co/spaces/ArdaKaratas/agent_hugging/tree/main"
            )
            use_agent_checkbox = gr.Checkbox(
                label="Use Agent (uncheck to use metadata.jsonl answers - testing only)",
                value=True
            )
            submit_btn = gr.Button("🚀 Process & Submit All Questions", variant="primary")
            status_output = gr.Textbox(
                label="Submission Status",
                lines=5,
                interactive=False
            )
            results_table = gr.Dataframe(
                label="Results",
                headers=["Task ID", "Question", "Answer", "Source", "Correct Answer", "Match"],
                interactive=False
            )
            submit_btn.click(
                fn=process_all_questions,
                inputs=[username_input, space_code_input, use_agent_checkbox],
                outputs=[status_output, results_table]
            )

        # Tab 3: View All Questions
        with gr.Tab("📋 View All Questions"):
            gr.Markdown("### Browse all GAIA benchmark questions")
            view_questions_btn = gr.Button("🔄 Load Questions", variant="primary")
            questions_display = gr.JSON(
                label="Questions"
            )
            view_questions_btn.click(
                fn=fetch_questions,
                outputs=[questions_display]
            )

# Agent class is already imported at the top of the file.
# The template can import it with: from app import Agent
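# Minimal usage sketch (an assumption about how an external script in this Space might
# call it; the question string is purely illustrative):
#
#     from app import Agent
#     agent = Agent()
#     print(agent("What is the capital of France?"))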

if __name__ == "__main__":
    # Launch the main app; 0.0.0.0:7860 is the standard binding for a Hugging Face Space.
    app.launch(share=False, server_name="0.0.0.0", server_port=7860)