from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Optional, Dict, Any
import os

app = FastAPI(title="Edge LLM API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")

AVAILABLE_MODELS = {
    "Qwen/Qwen3-4B-Thinking-2507": {
        "name": "Qwen3-4B-Thinking-2507",
        "supports_thinking": True,
        "description": "Shows thinking process",
        "size_gb": "~8GB"
    },
    "Qwen/Qwen3-4B-Instruct-2507": {
        "name": "Qwen3-4B-Instruct-2507",
        "supports_thinking": False,
        "description": "Direct instruction following",
        "size_gb": "~8GB"
    }
}
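
# The size estimates assume fp16/bf16 weights: roughly 2 bytes per parameter,
# so a 4B-parameter model needs on the order of 8 GB before activations and
# KV-cache overhead.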

# In-memory registry of loaded models and the currently selected model
models_cache: Dict[str, Dict[str, Any]] = {}
current_model_name: Optional[str] = None


class PromptRequest(BaseModel):
    prompt: str
    system_prompt: Optional[str] = None
    model_name: Optional[str] = None
    temperature: Optional[float] = 0.7
    max_new_tokens: Optional[int] = 1024


class PromptResponse(BaseModel):
    thinking_content: str
    content: str
    model_used: str
    supports_thinking: bool


class ModelInfo(BaseModel):
    model_name: str
    name: str
    supports_thinking: bool
    description: str
    size_gb: str
    is_loaded: bool


class ModelsResponse(BaseModel):
    models: list[ModelInfo]
    current_model: str


class ModelLoadRequest(BaseModel):
    model_name: str


class ModelUnloadRequest(BaseModel):
    model_name: str


def load_model_by_name(model_name: str):
    """Load a model into the cache"""
    global models_cache

    if model_name in models_cache:
        return True

    if model_name not in AVAILABLE_MODELS:
        return False

    try:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        models_cache[model_name] = {
            "model": model,
            "tokenizer": tokenizer
        }
        print(f"Model {model_name} loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return False
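
# Note: device_map="auto" requires the accelerate package to be installed; on a
# machine without a GPU the weights are placed on CPU, where float16 inference
# can be slow (or unsupported for some operations).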


def unload_model_by_name(model_name: str):
    """Unload a model from the cache"""
    global models_cache, current_model_name

    if model_name in models_cache:
        del models_cache[model_name]
        if current_model_name == model_name:
            current_model_name = None
        print(f"Model {model_name} unloaded")
        return True
    return False
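
# Deleting the cache entry only drops the Python references. To actually return
# GPU memory to the driver, a common follow-up (not part of the original code)
# is:
#
#     import gc
#     gc.collect()
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()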


@app.on_event("startup")
async def startup_event():
    """Startup event - don't load models by default"""
    print("🚀 Edge LLM API is starting up...")
    print("💡 Models will be loaded on demand")
@app.get("/") |
|
|
async def read_index(): |
|
|
"""Serve the React app""" |
|
|
return FileResponse('static/index.html') |
|
|
|
|
|
@app.get("/health") |
|
|
async def health_check(): |
|
|
return {"status": "healthy", "message": "Edge LLM API is running"} |
|
|
|
|
|
@app.get("/models", response_model=ModelsResponse) |
|
|
async def get_models(): |
|
|
"""Get available models and their status""" |
|
|
global current_model_name |
|
|
|
|
|
models = [] |
|
|
for model_name, info in AVAILABLE_MODELS.items(): |
|
|
models.append(ModelInfo( |
|
|
model_name=model_name, |
|
|
name=info["name"], |
|
|
supports_thinking=info["supports_thinking"], |
|
|
description=info["description"], |
|
|
size_gb=info["size_gb"], |
|
|
is_loaded=model_name in models_cache |
|
|
)) |
|
|
|
|
|
return ModelsResponse( |
|
|
models=models, |
|
|
current_model=current_model_name or "" |
|
|
) |
|
|
|
|
|
@app.post("/load-model") |
|
|
async def load_model(request: ModelLoadRequest): |
|
|
"""Load a specific model""" |
|
|
global current_model_name |
|
|
|
|
|
if request.model_name not in AVAILABLE_MODELS: |
|
|
raise HTTPException( |
|
|
status_code=400, |
|
|
detail=f"Model {request.model_name} not available" |
|
|
) |
|
|
|
|
|
success = load_model_by_name(request.model_name) |
|
|
if success: |
|
|
current_model_name = request.model_name |
|
|
return { |
|
|
"message": f"Model {request.model_name} loaded successfully", |
|
|
"current_model": current_model_name |
|
|
} |
|
|
else: |
|
|
raise HTTPException( |
|
|
status_code=500, |
|
|
detail=f"Failed to load model {request.model_name}" |
|
|
) |
|
|
|
|
|
@app.post("/unload-model") |
|
|
async def unload_model(request: ModelUnloadRequest): |
|
|
"""Unload a specific model""" |
|
|
global current_model_name |
|
|
|
|
|
success = unload_model_by_name(request.model_name) |
|
|
if success: |
|
|
return { |
|
|
"message": f"Model {request.model_name} unloaded successfully", |
|
|
"current_model": current_model_name or "" |
|
|
} |
|
|
else: |
|
|
raise HTTPException( |
|
|
status_code=404, |
|
|
detail=f"Model {request.model_name} not found in cache" |
|
|
) |
|
|
|
|
|
@app.post("/set-current-model") |
|
|
async def set_current_model(request: ModelLoadRequest): |
|
|
"""Set the current active model""" |
|
|
global current_model_name |
|
|
|
|
|
if request.model_name not in models_cache: |
|
|
raise HTTPException( |
|
|
status_code=400, |
|
|
detail=f"Model {request.model_name} is not loaded. Please load it first." |
|
|
) |
|
|
|
|
|
current_model_name = request.model_name |
|
|
return { |
|
|
"message": f"Current model set to {current_model_name}", |
|
|
"current_model": current_model_name |
|
|
} |
|
|
|
|
|
@app.post("/generate", response_model=PromptResponse) |
|
|
async def generate_text(request: PromptRequest): |
|
|
"""Generate text using the loaded model""" |
|
|
global current_model_name |
|
|
|
|
|
|
|
|
model_to_use = request.model_name if request.model_name else current_model_name |
|
|
|
|
|
if not model_to_use: |
|
|
raise HTTPException( |
|
|
status_code=400, |
|
|
detail="No model specified. Please load a model first." |
|
|
) |
|
|
|
|
|
if model_to_use not in models_cache: |
|
|
raise HTTPException( |
|
|
status_code=400, |
|
|
detail=f"Model {model_to_use} is not loaded. Please load it first." |
|
|
) |
|
|
|
|
|
try: |
|
|
model = models_cache[model_to_use]["model"] |
|
|
tokenizer = models_cache[model_to_use]["tokenizer"] |
|
|
model_info = AVAILABLE_MODELS[model_to_use] |
|
|
|
|
|
|
|
|
messages = [] |
|
|
if request.system_prompt: |
|
|
messages.append({"role": "system", "content": request.system_prompt}) |
|
|
messages.append({"role": "user", "content": request.prompt}) |
|
|
|
|
|
|
|
|
formatted_prompt = tokenizer.apply_chat_template( |
|
|
messages, |
|
|
tokenize=False, |
|
|
add_generation_prompt=True |
|
|
) |
|
|
|
|
|
|
|
|
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=request.max_new_tokens, |
|
|
temperature=request.temperature, |
|
|
do_sample=True, |
|
|
pad_token_id=tokenizer.eos_token_id |
|
|
) |
|
|
|
|
|
|
|
|
generated_tokens = outputs[0][inputs['input_ids'].shape[1]:] |
|
|
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) |
|
|
|
|
|
|
|
|
thinking_content = "" |
|
|
final_content = generated_text |
|
|
|
|
|
if model_info["supports_thinking"] and "<thinking>" in generated_text: |
|
|
parts = generated_text.split("<thinking>") |
|
|
if len(parts) > 1: |
|
|
thinking_part = parts[1] |
|
|
if "</thinking>" in thinking_part: |
|
|
thinking_content = thinking_part.split("</thinking>")[0].strip() |
|
|
remaining = thinking_part.split("</thinking>", 1)[1] if "</thinking>" in thinking_part else "" |
|
|
final_content = remaining.strip() |
|
|
|
|

        return PromptResponse(
            thinking_content=thinking_content,
            content=final_content,
            model_used=model_to_use,
            supports_thinking=model_info["supports_thinking"]
        )

    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
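
# Note: model.generate() is a blocking, GPU/CPU-bound call executed inside an
# async endpoint, so it ties up the event loop for the whole request. One
# possible workaround (illustrative, not part of the original code) is to
# offload it to a thread:
#
#     outputs = await asyncio.get_running_loop().run_in_executor(
#         None, lambda: model.generate(**inputs, max_new_tokens=request.max_new_tokens)
#     )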


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
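
# Example usage once the server is running (assuming it is reachable on
# localhost:7860):
#
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/load-model \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "Qwen/Qwen3-4B-Instruct-2507"}'
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello!", "max_new_tokens": 128}'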