# Fara-BrowserUse/backend/modal_fara_vllm.py
import json
from typing import Any
import aiohttp
import modal

vllm_image = (
modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
.entrypoint([])
.uv_pip_install(
"vllm==0.11.2",
"huggingface-hub==0.36.0",
"flashinfer-python==0.5.2",
)
.env({"HF_XET_HIGH_PERFORMANCE": "1"}) # faster model transfers
)

# Lightweight image for the trace-storage endpoint (doesn't need CUDA or vLLM)
trace_storage_image = modal.Image.debian_slim(python_version="3.12").uv_pip_install(
"fastapi", "uvicorn"
)

MODEL_NAME = "microsoft/Fara-7B"
MODEL_REVISION = None  # pin a revision for reproducibility; None uses the latest

hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
traces_vol = modal.Volume.from_name("fara-traces", create_if_missing=True)
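# The traces volume persists across runs; a sketch of inspecting it from the
# Modal CLI (the monthly folder and filename below are illustrative):
#   modal volume ls fara-traces
#   modal volume get fara-traces 2025-01/20250101_000000_abc123.json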
FAST_BOOT = True  # faster cold starts at the cost of steady-state throughput (see --enforce-eager below)

app = modal.App("fara-vllm")

MINUTES = 60  # seconds
VLLM_PORT = 5000  # port the vLLM server listens on inside the container
N_GPU = 1


@app.function(
image=vllm_image,
gpu="L40S",
scaledown_window=2 * MINUTES,
timeout=10 * MINUTES,
volumes={
"/root/.cache/huggingface": hf_cache_vol,
"/root/.cache/vllm": vllm_cache_vol,
},
)
@modal.concurrent(max_inputs=32)
@modal.web_server(
port=VLLM_PORT, startup_timeout=10 * MINUTES, requires_proxy_auth=True
)
def serve():
import subprocess
cmd = [
"vllm",
"serve",
"--uvicorn-log-level=info",
MODEL_NAME,
"--served-model-name",
MODEL_NAME,
"--host",
"0.0.0.0",
"--port",
        str(VLLM_PORT),
        "--dtype",
        "auto",  # let vLLM infer the dtype from the model config
        "--max-model-len",
        "32768",  # cap the context window so the KV cache fits in GPU memory (the model's default is larger)
]
if MODEL_REVISION:
cmd += ["--revision", MODEL_REVISION]
# enforce-eager disables both Torch compilation and CUDA graph capture
cmd += ["--enforce-eager" if FAST_BOOT else "--no-enforce-eager"]
# assume multiple GPUs are for splitting up large matrix multiplications
cmd += ["--tensor-parallel-size", str(N_GPU)]
print(cmd)
subprocess.Popen(" ".join(cmd), shell=True)
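
# Example client call (sketch): the server exposes the OpenAI-compatible chat API
# under /v1/chat/completions. Because `requires_proxy_auth=True`, requests must
# carry Modal proxy-auth token headers. The URL below is a placeholder for the one
# printed by `modal deploy` (also available via `serve.get_web_url()`):
#
#   import requests
#
#   resp = requests.post(
#       "https://<workspace>--fara-vllm-serve.modal.run/v1/chat/completions",
#       headers={"Modal-Key": "<token id>", "Modal-Secret": "<token secret>"},
#       json={
#           "model": "microsoft/Fara-7B",
#           "messages": [{"role": "user", "content": "Hello!"}],
#       },
#       timeout=60,
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
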
@app.function(
image=trace_storage_image,
volumes={"/traces": traces_vol},
timeout=2 * MINUTES,
)
@modal.fastapi_endpoint(method="POST", requires_proxy_auth=True)
def store_trace(trace_data: dict) -> dict:
"""
Store a task trace JSON in the Modal volume.
If a trace with the same ID and instruction already exists, it will be overwritten.
Expected trace_data structure:
{
"trace": { id, timestamp, instruction, modelId, isRunning },
"completion": { status, message, finalAnswer },
"metadata": { traceId, inputTokensUsed, outputTokensUsed, ... user_evaluation },
"steps": [...],
"exportedAt": "ISO timestamp"
}
"""
import glob
import os
from datetime import datetime
try:
# Extract trace ID and instruction for duplicate detection
trace_id = trace_data.get("trace", {}).get("id", "unknown")
instruction = trace_data.get("trace", {}).get("instruction", "")
# Create organized directory structure: /traces/YYYY-MM/
date_folder = datetime.now().strftime("%Y-%m")
trace_dir = f"/traces/{date_folder}"
os.makedirs(trace_dir, exist_ok=True)
# Check for existing trace with same ID (in all monthly folders)
existing_file = None
for monthly_dir in glob.glob("/traces/*/"):
for filepath in glob.glob(f"{monthly_dir}*_{trace_id}.json"):
# Found an existing file with this trace ID
# Verify it's the same trace by checking instruction
try:
with open(filepath, "r") as f:
existing_data = json.load(f)
existing_instruction = existing_data.get("trace", {}).get(
"instruction", ""
)
if existing_instruction == instruction:
existing_file = filepath
break
except (json.JSONDecodeError, IOError):
# If we can't read the file, skip it
continue
if existing_file:
break
if existing_file:
# Overwrite the existing file
filepath = existing_file
print(f"Overwriting existing trace: {filepath}")
else:
# Generate new filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{timestamp}_{trace_id}.json"
filepath = f"{trace_dir}/{filename}"
# Write trace to file
with open(filepath, "w") as f:
json.dump(trace_data, f, indent=2, default=str)
# Commit volume changes
traces_vol.commit()
        return {
            "success": True,
            "message": (
                "Trace updated successfully"
                if existing_file
                else "Trace stored successfully"
            ),
            "filepath": filepath,
            "trace_id": trace_id,
            "was_overwritten": existing_file is not None,
        }
except Exception as e:
return {
"success": False,
"error": str(e),
}
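
# Example upload (sketch): POST a trace JSON to this endpoint's URL (printed by
# `modal deploy`), again with proxy-auth headers. The field values below are
# illustrative placeholders following the structure documented in the docstring:
#
#   import requests
#
#   requests.post(
#       "https://<workspace>--fara-vllm-store-trace.modal.run",
#       headers={"Modal-Key": "<token id>", "Modal-Secret": "<token secret>"},
#       json={
#           "trace": {"id": "abc123", "instruction": "Find the cheapest flight"},
#           "completion": {"status": "completed", "message": "", "finalAnswer": ""},
#           "metadata": {"traceId": "abc123"},
#           "steps": [],
#           "exportedAt": "2025-01-01T00:00:00Z",
#       },
#   )
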
@app.local_entrypoint()
async def test(test_timeout=10 * MINUTES, content=None, twice=True):
    # NB: `serve` requires proxy auth, so these requests would also need
    # Modal-Key/Modal-Secret headers unless that protection is disabled.
    url = serve.get_web_url()
system_prompt = {
"role": "system",
"content": "You are an AI assistant specialized in computer use tasks.",
}
if content is None:
content = "Hello, what can you do to help with computer tasks?"
messages = [ # OpenAI chat format
system_prompt,
{"role": "user", "content": content},
]
async with aiohttp.ClientSession(base_url=url) as session:
print(f"Running health check for server at {url}")
async with session.get("/health", timeout=test_timeout - 1 * MINUTES) as resp:
up = resp.status == 200
assert up, f"Failed health check for server at {url}"
print(f"Successful health check for server at {url}")
print(f"Sending messages to {url}:", *messages, sep="\n\t")
await _send_request(session, MODEL_NAME, messages)
if twice:
messages[0]["content"] = "You are a helpful assistant."
print(f"Sending messages to {url}:", *messages, sep="\n\t")
await _send_request(session, MODEL_NAME, messages)

async def _send_request(
session: aiohttp.ClientSession, model: str, messages: list
) -> None:
    # `stream=True` tells an OpenAI-compatible backend to stream chunks
payload: dict[str, Any] = {"messages": messages, "model": model, "stream": True}
headers = {"Content-Type": "application/json", "Accept": "text/event-stream"}
async with session.post(
"/v1/chat/completions", json=payload, headers=headers, timeout=1 * MINUTES
) as resp:
        resp.raise_for_status()  # surface HTTP errors before consuming the stream
        async for raw in resp.content:
# extract new content and stream it
line = raw.decode().strip()
if not line or line == "data: [DONE]":
continue
if line.startswith("data: "): # SSE prefix
line = line[len("data: ") :]
chunk = json.loads(line)
assert (
chunk["object"] == "chat.completion.chunk"
) # or something went horribly wrong
            # role-only or final chunks may carry no "content" in the delta
            print(chunk["choices"][0]["delta"].get("content") or "", end="")
print()
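
# Usage (sketch): `modal run modal_fara_vllm.py` executes the `test` entrypoint
# above; `modal deploy modal_fara_vllm.py` deploys both the vLLM server and the
# trace-storage endpoint.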