"""
FARA Backend Server for HuggingFace Space.

Provides WebSocket communication and REST API for the React frontend.
"""

import asyncio
import base64
import logging
import os
import random
import shutil

# Import FARA components
import sys
import tempfile
import uuid
from datetime import datetime
from typing import Dict, Optional, Set

import httpx
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from playwright._impl._errors import TargetClosedError

sys.path.insert(0, "/app")

from fara import FaraAgent
from fara.browser.browser_bb import BrowserBB

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Modal trace storage configuration
MODAL_TRACE_STORAGE_URL = os.environ.get("MODAL_TRACE_STORAGE_URL", "")
MODAL_TOKEN_ID = os.environ.get("MODAL_TOKEN_ID", "")
MODAL_TOKEN_SECRET = os.environ.get("MODAL_TOKEN_SECRET", "")

# Modal vLLM endpoint configuration (from environment variables for HF Spaces)
# Includes proxy auth headers for authenticated Modal endpoints
ENDPOINT_CONFIG = {
    "model": os.environ.get("FARA_MODEL_NAME", "microsoft/Fara-7B"),
    "base_url": os.environ.get("FARA_ENDPOINT_URL"),
    "api_key": os.environ.get("FARA_API_KEY", "not-needed"),
    "default_headers": {
        "Modal-Key": MODAL_TOKEN_ID,
        "Modal-Secret": MODAL_TOKEN_SECRET,
    }
    if MODAL_TOKEN_ID and MODAL_TOKEN_SECRET
    else None,
}

# Available models (for the frontend dropdown)
AVAILABLE_MODELS = ["microsoft/Fara-7B"]

# Upper bound on agent rounds per task. Passed to the agent as ``max_rounds``
# and reported to the frontend as ``maxSteps`` so the two cannot drift apart.
MAX_STEPS = 50

# Example prompts served by /api/random-question
EXAMPLE_QUESTIONS = [
    "Search for the latest news about AI agents",
    "Find the weather forecast for San Francisco",
    "Go to GitHub and search for 'computer use agent'",
    "Find the top trending repositories on GitHub today",
    "Search for Python tutorials on YouTube",
    "Look up the current stock price of Microsoft",
    "Find the schedule for upcoming SpaceX launches",
    "Search for healthy breakfast recipes",
]

app = FastAPI(title="FARA Backend")

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended for the deployed Space.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Store active connections and their sessions
active_connections: Dict[str, WebSocket] = {}
active_sessions: Dict[str, "FaraSession"] = {}

# Strong references to running session tasks. The event loop only keeps weak
# references to tasks, so without this set a task could be garbage-collected
# before it finishes.
background_tasks: Set["asyncio.Task"] = set()


class FaraSession:
    """Manages a single FARA agent session bound to one WebSocket."""

    def __init__(self, trace_id: str, websocket: WebSocket):
        """Record the trace id and socket; the browser/agent are created lazily."""
        self.trace_id = trace_id
        self.websocket = websocket
        self.agent: Optional[FaraAgent] = None
        self.browser_manager: Optional[BrowserBB] = None
        self.screenshots_dir: Optional[str] = None
        self.is_running = False
        self.should_stop = False
        self.step_count = 0
        self.start_time: Optional[datetime] = None
        self.total_input_tokens = 0
        self.total_output_tokens = 0

    async def initialize(self, start_page: str = "https://www.bing.com/"):
        """Create the temp download directory, the browser manager and the agent.

        Args:
            start_page: URL handed to the agent as its start page.

        Returns:
            True once the agent's own ``initialize()`` has completed.
        """
        # Create temp directory for screenshots
        self.screenshots_dir = tempfile.mkdtemp(prefix="fara_screenshots_")

        # Initialize browser manager (headless for HF Space)
        self.browser_manager = BrowserBB(
            headless=True,
            viewport_height=900,
            viewport_width=1440,
            page_script_path=None,
            browser_channel="chromium",
            browser_data_dir=None,
            downloads_folder=self.screenshots_dir,
            to_resize_viewport=True,
            single_tab_mode=True,
            animate_actions=False,
            use_browser_base=False,
            logger=logger,
        )

        self.agent = FaraAgent(
            browser_manager=self.browser_manager,
            client_config=ENDPOINT_CONFIG,
            start_page=start_page,
            downloads_folder=self.screenshots_dir,
            save_screenshots=True,
            max_rounds=MAX_STEPS,
        )

        await self.agent.initialize()
        return True

    async def send_event(self, event: dict):
        """Send an event to the connected WebSocket; failures are logged, not raised."""
        try:
            await self.websocket.send_json(event)
        except Exception as e:
            logger.error("Error sending event: %s", e)

    async def _capture_data_url(self, agent, page) -> str:
        """Take a screenshot of ``page`` and return it as a PNG data URL."""
        screenshot_bytes = await agent._playwright_controller.get_screenshot(page)
        encoded = base64.b64encode(screenshot_bytes).decode()
        return f"data:image/png;base64,{encoded}"

    async def get_screenshot_base64(self) -> Optional[str]:
        """Return the current browser screenshot as a base64 data URL.

        If the page is closed mid-capture, the active page is looked up again
        and the capture is retried once. Returns None when there is no agent,
        no page, or the capture fails.
        """
        agent = self.agent
        if not agent:
            return None

        try:
            # Get the current active page from the browser context
            page = self._get_active_page()
            if page:
                return await self._capture_data_url(agent, page)
        except TargetClosedError:
            logger.warning(
                "Page closed while getting screenshot, attempting recovery..."
            )
            page = self._get_active_page()
            if page:
                try:
                    return await self._capture_data_url(agent, page)
                except Exception as e:
                    logger.error("Recovery screenshot failed: %s", e)
        except Exception as e:
            logger.error("Error getting screenshot: %s", e)
        return None

    def _get_active_page(self):
        """Return the last page of the browser context, else the agent's page.

        Returns None when no agent is attached.
        """
        agent = self.agent
        if not agent:
            return None

        manager = agent.browser_manager
        if manager and manager._context:
            pages = manager._context.pages
            if pages:
                # The last entry in the context's page list is treated as active
                return pages[-1]
        return agent._page

    async def run_task(self, instruction: str, model_id: str):
        """Run a task and stream results via WebSocket.

        Emits ``agent_start``, then delegates to the streaming loop. Any
        exception is reported to the client as ``agent_error``. Resources are
        always released on exit.

        Args:
            instruction: The user's task text.
            model_id: Model identifier echoed back in the ``agent_start`` event.
        """
        self.is_running = True
        self.should_stop = False
        self.step_count = 0
        self.start_time = datetime.now()
        self.total_input_tokens = 0
        self.total_output_tokens = 0

        try:
            # Send agent_start event
            await self.send_event(
                {
                    "type": "agent_start",
                    "agentTrace": {
                        "id": self.trace_id,
                        "instruction": instruction,
                        "modelId": model_id,
                        "timestamp": self.start_time.isoformat(),
                        "isRunning": True,
                        "traceMetadata": {
                            "traceId": self.trace_id,
                            "inputTokensUsed": 0,
                            "outputTokensUsed": 0,
                            "duration": 0,
                            "numberOfSteps": 0,
                            "maxSteps": MAX_STEPS,
                            "completed": False,
                        },
                    },
                }
            )

            # Initialize agent
            await self.initialize()

            # Run the agent with custom loop to stream progress
            await self._run_agent_with_streaming(instruction)

        except Exception as e:
            logger.exception("Error running agent task")
            await self.send_event({"type": "agent_error", "error": str(e)})
        finally:
            self.is_running = False
            await self.close()

    async def _run_agent_with_streaming(self, user_message: str):
        """Run the agent loop and stream each step to the frontend.

        Each round generates a model call, executes the resulting action,
        captures a screenshot and emits ``agent_progress``. The loop ends with
        ``agent_complete`` on a stop action, on a user stop request, or after
        ``agent.max_rounds`` rounds. Step errors are reported as
        ``agent_error`` and end the loop.

        Raises:
            RuntimeError: If the agent has no page after initialization, or a
                captcha wait times out.
        """
        agent = self.agent

        # Initialize if not already done
        # NOTE(review): run_task has already awaited agent.initialize() via
        # self.initialize(); this presumes the call is idempotent -- confirm.
        await agent.initialize()

        # Explicit check rather than ``assert`` so it survives ``python -O``
        if agent._page is None:
            raise RuntimeError("Page should be initialized")

        # Get initial screenshot
        scaled_screenshot = await agent._get_scaled_screenshot()

        if agent.save_screenshots:
            await agent._playwright_controller.get_screenshot(
                agent._page,
                path=os.path.join(
                    agent.downloads_folder, f"screenshot{agent._num_actions}.png"
                ),
            )

        # Add user message to chat history
        from fara.types import ImageObj, UserMessage

        agent._chat_history.append(
            UserMessage(
                content=[ImageObj.from_pil(scaled_screenshot), user_message],
                is_original=True,
            )
        )

        is_stop_action = False

        for i in range(agent.max_rounds):
            if self.should_stop:
                # User requested stop
                await self.send_event(
                    {
                        "type": "agent_complete",
                        "traceMetadata": self._get_metadata(),
                        "final_state": "stopped",
                    }
                )
                return

            is_first_round = i == 0
            step_start_time = datetime.now()

            # Wait for captcha if needed
            if not agent.browser_manager._captcha_event.is_set():
                logger.info("Waiting 60s for captcha to finish...")
                captcha_solved = await agent.wait_for_captcha_with_timeout(60)
                if (
                    not captcha_solved
                    and not agent.browser_manager._captcha_event.is_set()
                ):
                    raise RuntimeError("Captcha timed out")

            try:
                # Generate model response
                function_call, raw_response = await agent.generate_model_call(
                    is_first_round, scaled_screenshot if is_first_round else None
                )

                # Parse response
                thoughts, action_dict = agent._parse_thoughts_and_action(raw_response)
                action_args = action_dict.get("arguments", {})
                action = action_args["action"]

                logger.info(
                    "\nThought #%d: %s\nAction #%d: %s",
                    i + 1,
                    thoughts,
                    i + 1,
                    action,
                )

                # Execute action with recovery for page changes
                try:
                    (
                        is_stop_action,
                        _,  # screenshot returned by the agent is not used here
                        action_description,
                    ) = await agent.execute_action(function_call)
                except TargetClosedError:
                    logger.warning(
                        "Page closed during action execution, attempting recovery..."
                    )
                    # Try to recover the page reference
                    new_page = self._get_active_page()
                    if new_page and new_page != agent._page:
                        logger.info("Recovered with new active page")
                        agent._page = new_page
                        # Wait for the page to stabilize
                        await asyncio.sleep(1)
                        action_description = (
                            "Action completed (page navigation occurred)"
                        )
                        is_stop_action = False
                    else:
                        # No replacement page: let the outer handler report it
                        raise

                # Sync the agent's page reference with the active page
                active_page = self._get_active_page()
                if active_page and active_page != agent._page:
                    logger.info("Updating agent page reference to active page")
                    agent._page = active_page

                # Get screenshot for this step
                screenshot_base64 = await self.get_screenshot_base64()

            except TargetClosedError as e:
                logger.error("Unrecoverable page error: %s", e)
                await self.send_event(
                    {
                        "type": "agent_error",
                        "error": f"Browser page closed unexpectedly: {e}",
                    }
                )
                return
            except Exception as e:
                logger.exception("Error in agent step %d", i + 1)
                await self.send_event({"type": "agent_error", "error": str(e)})
                return

            # Calculate step duration and tokens (estimated)
            step_duration = (datetime.now() - step_start_time).total_seconds()
            step_input_tokens = 1000  # Estimated
            step_output_tokens = len(raw_response) // 4  # Rough estimate

            self.total_input_tokens += step_input_tokens
            self.total_output_tokens += step_output_tokens
            self.step_count += 1

            # Create step object
            step = {
                "stepId": str(uuid.uuid4()),
                "traceId": self.trace_id,
                "stepNumber": self.step_count,
                "thought": thoughts,
                "actions": [
                    {
                        "function_name": action,
                        "description": action_description,
                        "parameters": action_args,
                    }
                ],
                "image": screenshot_base64,
                "duration": step_duration,
                "inputTokensUsed": step_input_tokens,
                "outputTokensUsed": step_output_tokens,
                "timestamp": datetime.now().isoformat(),
            }

            # Send progress event
            await self.send_event(
                {
                    "type": "agent_progress",
                    "agentStep": step,
                    "traceMetadata": self._get_metadata(),
                }
            )

            if is_stop_action:
                break

        # Send completion event
        final_state = "success" if is_stop_action else "max_steps_reached"
        await self.send_event(
            {
                "type": "agent_complete",
                "traceMetadata": self._get_metadata(completed=True),
                "final_state": final_state,
            }
        )

    def _get_metadata(self, completed: bool = False) -> dict:
        """Return the current trace metadata.

        Args:
            completed: Value reported in the ``completed`` field.

        Returns:
            Dict with token totals, elapsed seconds since ``start_time``
            (0 if the task has not started), step count and step limit.
        """
        duration = 0
        if self.start_time:
            duration = (datetime.now() - self.start_time).total_seconds()

        return {
            "traceId": self.trace_id,
            "inputTokensUsed": self.total_input_tokens,
            "outputTokensUsed": self.total_output_tokens,
            "duration": duration,
            "numberOfSteps": self.step_count,
            "maxSteps": MAX_STEPS,
            "completed": completed,
        }

    async def stop(self):
        """Request the agent to stop; the loop checks the flag at each round."""
        self.should_stop = True

    async def close(self):
        """Release the agent and delete the temporary screenshot directory.

        Safe to call more than once, including concurrently: references are
        detached from ``self`` before any await, so a second caller finds
        nothing left to close.
        """
        agent, self.agent = self.agent, None
        self.browser_manager = None
        if agent:
            try:
                await agent.close()
            except Exception as e:
                logger.error("Error closing agent: %s", e)

        screenshots_dir, self.screenshots_dir = self.screenshots_dir, None
        if screenshots_dir and os.path.exists(screenshots_dir):
            try:
                shutil.rmtree(screenshots_dir)
            except Exception as e:
                logger.error("Error cleaning up screenshots: %s", e)


def _launch_session_task(
    session: FaraSession, instruction: str, model_id: str
) -> "asyncio.Task":
    """Start ``session.run_task`` in the background and track it until done.

    The task is held in ``background_tasks`` so it cannot be garbage-collected
    mid-run. When it finishes, the session is removed from ``active_sessions``
    unless a newer session has taken over the same trace id.
    """
    task = asyncio.create_task(session.run_task(instruction, model_id))
    background_tasks.add(task)

    def _finalize(done: "asyncio.Task") -> None:
        background_tasks.discard(done)
        if active_sessions.get(session.trace_id) is session:
            del active_sessions[session.trace_id]
        # Retrieve any exception so it is logged instead of reported by the
        # event loop as "never retrieved".
        if not done.cancelled() and done.exception() is not None:
            logger.error(
                "Session task %s failed: %r", session.trace_id, done.exception()
            )

    task.add_done_callback(_finalize)
    return task


@app.get("/api/models")
async def get_models():
    """Return available models"""
    return JSONResponse(content=AVAILABLE_MODELS)


@app.post("/api/traces")
async def store_trace(trace_data: dict):
    """
    Store a task trace by forwarding to the Modal trace storage endpoint.
    This keeps Modal credentials on the server side.

    Returns the upstream JSON on HTTP 200. Otherwise returns a JSON error:
    503 when storage or credentials are not configured, the upstream status
    code on an upstream error, 504 on timeout, and 500 on any other failure.
    """
    if not MODAL_TRACE_STORAGE_URL:
        logger.warning("Modal trace storage URL not configured")
        return JSONResponse(
            status_code=503,
            content={"success": False, "error": "Trace storage not configured"},
        )

    if not MODAL_TOKEN_ID or not MODAL_TOKEN_SECRET:
        logger.warning("Modal proxy auth credentials not configured")
        return JSONResponse(
            status_code=503,
            content={"success": False, "error": "Modal auth not configured"},
        )

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                MODAL_TRACE_STORAGE_URL,
                json=trace_data,
                headers={
                    "Content-Type": "application/json",
                    "Modal-Key": MODAL_TOKEN_ID,
                    "Modal-Secret": MODAL_TOKEN_SECRET,
                },
            )

            if response.status_code == 200:
                result = response.json()
                logger.info(
                    "Trace stored successfully: %s",
                    result.get("trace_id", "unknown"),
                )
                return JSONResponse(content=result)

            error_text = response.text
            logger.error(
                "Failed to store trace: %s - %s", response.status_code, error_text
            )
            return JSONResponse(
                status_code=response.status_code,
                content={
                    "success": False,
                    "error": f"Modal API error: {error_text}",
                },
            )
    except httpx.TimeoutException:
        logger.error("Timeout storing trace to Modal")
        return JSONResponse(
            status_code=504,
            content={"success": False, "error": "Timeout connecting to trace storage"},
        )
    except Exception as e:
        logger.exception("Error storing trace")
        return JSONResponse(
            status_code=500, content={"success": False, "error": str(e)}
        )


@app.get("/api/random-question")
async def get_random_question():
    """Return a random example question"""
    return JSONResponse(content={"question": random.choice(EXAMPLE_QUESTIONS)})


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time communication.

    Handles three client message types: ``user_task`` (start a session in the
    background), ``stop_task`` (flag a session to stop) and ``ping``. On
    disconnect, every session bound to this socket is stopped and closed.
    """
    await websocket.accept()

    # Generate a unique connection ID
    connection_id = str(uuid.uuid4())
    active_connections[connection_id] = websocket

    try:
        # Send a heartbeat carrying a freshly generated UUID. Sent inside the
        # try block so a failed send still reaches the cleanup below.
        trace_id = str(uuid.uuid4())
        await websocket.send_json(
            {
                "type": "heartbeat",
                "uuid": trace_id,
                "timestamp": datetime.now().isoformat(),
            }
        )

        while True:
            # Wait for messages from the client
            data = await websocket.receive_json()
            message_type = data.get("type")

            if message_type == "user_task":
                # Extract task details; tolerate a null/missing trace or id
                trace = data.get("trace") or {}
                trace_id = trace.get("id") or str(uuid.uuid4())
                instruction = trace.get("instruction", "")
                model_id = trace.get("modelId", "microsoft/Fara-7B")

                # Create and start session
                session = FaraSession(trace_id, websocket)
                active_sessions[trace_id] = session

                # Run the task in the background
                _launch_session_task(session, instruction, model_id)

            elif message_type == "stop_task":
                # Stop the running task
                requested_id = data.get("trace_id")
                if requested_id and requested_id in active_sessions:
                    await active_sessions[requested_id].stop()

            elif message_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected: %s", connection_id)
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        # Clean up
        active_connections.pop(connection_id, None)

        # Clean up any sessions for this connection. Work from a snapshot:
        # awaiting close() yields to the event loop, where finishing tasks may
        # remove entries from active_sessions.
        orphaned = [
            (session_id, session)
            for session_id, session in list(active_sessions.items())
            if session.websocket is websocket
        ]
        for session_id, session in orphaned:
            # Flag the loop to exit before tearing the browser down under it
            await session.stop()
            await session.close()
            if active_sessions.get(session_id) is session:
                del active_sessions[session_id]


@app.get("/api/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)