import json
from typing import Any

import aiohttp
import modal
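
# vLLM server image: CUDA devel base plus pinned vLLM, FlashInfer, and huggingface-hub.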
vllm_image = (
    modal.Image.from_registry("nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.12")
    .entrypoint([])
    .uv_pip_install(
        "vllm==0.11.2",
        "huggingface-hub==0.36.0",
        "flashinfer-python==0.5.2",
    )
    .env({"HF_XET_HIGH_PERFORMANCE": "1"})
)
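
# Lightweight CPU image for the FastAPI trace-storage endpoint.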
trace_storage_image = modal.Image.debian_slim(python_version="3.12").uv_pip_install(
    "fastapi", "uvicorn"
)
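
# Model to serve; set MODEL_REVISION to a commit hash to pin deploys to an exact snapshot.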
MODEL_NAME = "microsoft/Fara-7B"
MODEL_REVISION = None
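
# Volumes persist model weights, the vLLM compile cache, and collected traces across containers.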
hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
traces_vol = modal.Volume.from_name("fara-traces", create_if_missing=True)
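
# Trade throughput for boot speed: FAST_BOOT runs vLLM in eager mode (no CUDA graphs).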
FAST_BOOT = True

app = modal.App("fara-vllm")

MINUTES = 60  # seconds
VLLM_PORT = 5000
N_GPU = 1
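

# OpenAI-compatible vLLM server. With requires_proxy_auth=True, every request
# must also present Modal proxy-auth credentials.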
@app.function(
    image=vllm_image,
    gpu="L40S",
    scaledown_window=2 * MINUTES,
    timeout=10 * MINUTES,
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
    },
)
@modal.concurrent(max_inputs=32)
@modal.web_server(
    port=VLLM_PORT, startup_timeout=10 * MINUTES, requires_proxy_auth=True
)
def serve():
    import subprocess
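
    # Assemble the vLLM CLI invocation from the module-level configuration above.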
    cmd = [
        "vllm",
        "serve",
        "--uvicorn-log-level=info",
        MODEL_NAME,
        "--served-model-name",
        MODEL_NAME,
        "--host",
        "0.0.0.0",
        "--port",
        str(VLLM_PORT),
        "--dtype",
        "auto",
        "--max-model-len",
        "32768",
    ]

    if MODEL_REVISION:
        cmd += ["--revision", MODEL_REVISION]
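
    # Eager mode boots faster; CUDA graph capture gives better throughput once warm.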
    cmd += ["--enforce-eager" if FAST_BOOT else "--no-enforce-eager"]

    cmd += ["--tensor-parallel-size", str(N_GPU)]

    print(cmd)
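
    # Start vLLM in the background and return; @modal.web_server waits for the
    # port to start accepting connections before routing traffic.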
    subprocess.Popen(" ".join(cmd), shell=True)
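

# POST endpoint that persists task traces as JSON files in the fara-traces Volume.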
@app.function(
    image=trace_storage_image,
    volumes={"/traces": traces_vol},
    timeout=2 * MINUTES,
)
@modal.fastapi_endpoint(method="POST", requires_proxy_auth=True)
def store_trace(trace_data: dict) -> dict:
    """
    Store a task trace JSON in the Modal volume.
    If a trace with the same ID and instruction already exists, it will be overwritten.

    Expected trace_data structure:
    {
        "trace": { id, timestamp, instruction, modelId, isRunning },
        "completion": { status, message, finalAnswer },
        "metadata": { traceId, inputTokensUsed, outputTokensUsed, ... user_evaluation },
        "steps": [...],
        "exportedAt": "ISO timestamp"
    }
    """
    import glob
    import os
    from datetime import datetime

    try:
        trace_id = trace_data.get("trace", {}).get("id", "unknown")
        instruction = trace_data.get("trace", {}).get("instruction", "")

        date_folder = datetime.now().strftime("%Y-%m")
        trace_dir = f"/traces/{date_folder}"
        os.makedirs(trace_dir, exist_ok=True)
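
        # Look across all monthly folders for an earlier trace with the same id
        # and instruction, so re-runs overwrite instead of duplicating.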
        existing_file = None
        for monthly_dir in glob.glob("/traces/*/"):
            for filepath in glob.glob(f"{monthly_dir}*_{trace_id}.json"):
                try:
                    with open(filepath, "r") as f:
                        existing_data = json.load(f)
                    existing_instruction = existing_data.get("trace", {}).get(
                        "instruction", ""
                    )
                    if existing_instruction == instruction:
                        existing_file = filepath
                        break
                except (json.JSONDecodeError, IOError):
                    continue
            if existing_file:
                break

        if existing_file:
            filepath = existing_file
            print(f"Overwriting existing trace: {filepath}")
        else:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{timestamp}_{trace_id}.json"
            filepath = f"{trace_dir}/{filename}"
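
        # Write the trace, then commit so other containers see the change.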
        with open(filepath, "w") as f:
            json.dump(trace_data, f, indent=2, default=str)

        traces_vol.commit()

        return {
            "success": True,
            "message": "Trace stored successfully"
            if not existing_file
            else "Trace updated successfully",
            "filepath": filepath,
            "trace_id": trace_id,
            "was_overwritten": existing_file is not None,
        }
    except Exception as e:
        return {
            "success": False,
            "error": str(e),
        }
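

# Illustrative client call for the trace endpoint (hypothetical URL; the real one is
# printed by `modal deploy`, and Modal proxy-auth headers are required as well):
#
#   curl -X POST https://<workspace>--fara-vllm-store-trace.modal.run \
#     -H "Content-Type: application/json" \
#     -d '{"trace": {"id": "t1", "instruction": "..."}, "steps": []}'


# Smoke test: `modal run` this file to health-check the server and stream two completions.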
@app.local_entrypoint()
async def test(test_timeout=10 * MINUTES, content=None, twice=True):
    url = serve.get_web_url()

    system_prompt = {
        "role": "system",
        "content": "You are an AI assistant specialized in computer use tasks.",
    }
    if content is None:
        content = "Hello, what can you do to help with computer tasks?"

    messages = [
        system_prompt,
        {"role": "user", "content": content},
    ]

    async with aiohttp.ClientSession(base_url=url) as session:
        print(f"Running health check for server at {url}")
        async with session.get("/health", timeout=test_timeout - 1 * MINUTES) as resp:
            up = resp.status == 200
        assert up, f"Failed health check for server at {url}"
        print(f"Successful health check for server at {url}")

        print(f"Sending messages to {url}:", *messages, sep="\n\t")
        await _send_request(session, MODEL_NAME, messages)
        if twice:
            messages[0]["content"] = "You are a helpful assistant."
            print(f"Sending messages to {url}:", *messages, sep="\n\t")
            await _send_request(session, MODEL_NAME, messages)
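

# Stream a chat completion from the server, printing tokens as they arrive.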
async def _send_request(
    session: aiohttp.ClientSession, model: str, messages: list
) -> None:
    payload: dict[str, Any] = {"messages": messages, "model": model, "stream": True}

    headers = {"Content-Type": "application/json", "Accept": "text/event-stream"}

    async with session.post(
        "/v1/chat/completions", json=payload, headers=headers, timeout=1 * MINUTES
    ) as resp:
        # Check the HTTP status once, before consuming the SSE stream.
        resp.raise_for_status()
        async for raw in resp.content:
            line = raw.decode().strip()
            if not line or line == "data: [DONE]":
                continue
            if line.startswith("data: "):
                line = line[len("data: ") :]

            chunk = json.loads(line)
            assert chunk["object"] == "chat.completion.chunk"
            # Some chunks (e.g. the final one) may omit "content" in the delta.
            print(chunk["choices"][0]["delta"].get("content") or "", end="")
    print()