# edgellm/app.py: FastAPI backend for the Edge LLM Hugging Face Space
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Optional, Dict, Any
import os
app = FastAPI(title="Edge LLM API")
# Enable CORS for Hugging Face Space
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allow all origins for HF Space
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount static files
app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")
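# The static/ directory is expected to hold the built React frontend:
# static/index.html (served at "/") plus the bundled files under static/assets.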
# Available models
AVAILABLE_MODELS = {
"Qwen/Qwen3-4B-Thinking-2507": {
"name": "Qwen3-4B-Thinking-2507",
"supports_thinking": True,
"description": "Shows thinking process",
"size_gb": "~8GB"
},
"Qwen/Qwen3-4B-Instruct-2507": {
"name": "Qwen3-4B-Instruct-2507",
"supports_thinking": False,
"description": "Direct instruction following",
"size_gb": "~8GB"
}
}
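# To expose another checkpoint, add an entry keyed by its Hugging Face repo id;
# "supports_thinking" controls whether /generate tries to split the model's
# reasoning out from the final answer.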
# Global model cache
models_cache: Dict[str, Dict[str, Any]] = {}
current_model_name = None # No model loaded by default
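# models_cache maps a repo id to {"model": ..., "tokenizer": ...}; current_model_name
# is the default used by /generate when a request does not name a model.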
class PromptRequest(BaseModel):
prompt: str
system_prompt: Optional[str] = None
model_name: Optional[str] = None
temperature: Optional[float] = 0.7
max_new_tokens: Optional[int] = 1024
class PromptResponse(BaseModel):
thinking_content: str
content: str
model_used: str
supports_thinking: bool
class ModelInfo(BaseModel):
model_name: str
name: str
supports_thinking: bool
description: str
size_gb: str
is_loaded: bool
class ModelsResponse(BaseModel):
models: list[ModelInfo]
current_model: str
class ModelLoadRequest(BaseModel):
model_name: str
class ModelUnloadRequest(BaseModel):
model_name: str
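# Loading uses float16 weights with device_map="auto", which requires the
# `accelerate` package; layers are placed on GPU when one is available and
# fall back to CPU otherwise.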
def load_model_by_name(model_name: str):
"""Load a model into the cache"""
global models_cache
if model_name in models_cache:
return True
if model_name not in AVAILABLE_MODELS:
return False
try:
print(f"Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
models_cache[model_name] = {
"model": model,
"tokenizer": tokenizer
}
print(f"Model {model_name} loaded successfully")
return True
except Exception as e:
print(f"Error loading model {model_name}: {e}")
return False
def unload_model_by_name(model_name: str):
"""Unload a model from the cache"""
global models_cache, current_model_name
if model_name in models_cache:
del models_cache[model_name]
if current_model_name == model_name:
current_model_name = None
print(f"Model {model_name} unloaded")
return True
return False
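# Note: unloading only drops the Python references held above; on GPU deployments
# you may also want gc.collect() and torch.cuda.empty_cache() afterwards so the
# memory is returned to the device promptly.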
@app.on_event("startup")
async def startup_event():
"""Startup event - don't load models by default"""
print("🚀 Edge LLM API is starting up...")
print("💡 Models will be loaded on demand")
@app.get("/")
async def read_index():
"""Serve the React app"""
return FileResponse('static/index.html')
@app.get("/health")
async def health_check():
return {"status": "healthy", "message": "Edge LLM API is running"}
@app.get("/models", response_model=ModelsResponse)
async def get_models():
"""Get available models and their status"""
global current_model_name
models = []
for model_name, info in AVAILABLE_MODELS.items():
models.append(ModelInfo(
model_name=model_name,
name=info["name"],
supports_thinking=info["supports_thinking"],
description=info["description"],
size_gb=info["size_gb"],
is_loaded=model_name in models_cache
))
return ModelsResponse(
models=models,
current_model=current_model_name or ""
)
@app.post("/load-model")
async def load_model(request: ModelLoadRequest):
"""Load a specific model"""
global current_model_name
if request.model_name not in AVAILABLE_MODELS:
raise HTTPException(
status_code=400,
detail=f"Model {request.model_name} not available"
)
success = load_model_by_name(request.model_name)
if success:
current_model_name = request.model_name
return {
"message": f"Model {request.model_name} loaded successfully",
"current_model": current_model_name
}
else:
raise HTTPException(
status_code=500,
detail=f"Failed to load model {request.model_name}"
)
@app.post("/unload-model")
async def unload_model(request: ModelUnloadRequest):
"""Unload a specific model"""
global current_model_name
success = unload_model_by_name(request.model_name)
if success:
return {
"message": f"Model {request.model_name} unloaded successfully",
"current_model": current_model_name or ""
}
else:
raise HTTPException(
status_code=404,
detail=f"Model {request.model_name} not found in cache"
)
@app.post("/set-current-model")
async def set_current_model(request: ModelLoadRequest):
"""Set the current active model"""
global current_model_name
if request.model_name not in models_cache:
raise HTTPException(
status_code=400,
detail=f"Model {request.model_name} is not loaded. Please load it first."
)
current_model_name = request.model_name
return {
"message": f"Current model set to {current_model_name}",
"current_model": current_model_name
}
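# Note: model.generate below is a blocking call inside an async endpoint, so a
# long generation ties up the event loop. That is acceptable for a single-user
# Space demo, but concurrent requests will be served one at a time.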
@app.post("/generate", response_model=PromptResponse)
async def generate_text(request: PromptRequest):
"""Generate text using the loaded model"""
global current_model_name
# Use the model specified in request, or fall back to current model
model_to_use = request.model_name if request.model_name else current_model_name
if not model_to_use:
raise HTTPException(
status_code=400,
detail="No model specified. Please load a model first."
)
if model_to_use not in models_cache:
raise HTTPException(
status_code=400,
detail=f"Model {model_to_use} is not loaded. Please load it first."
)
try:
model = models_cache[model_to_use]["model"]
tokenizer = models_cache[model_to_use]["tokenizer"]
model_info = AVAILABLE_MODELS[model_to_use]
# Build the prompt
messages = []
if request.system_prompt:
messages.append({"role": "system", "content": request.system_prompt})
messages.append({"role": "user", "content": request.prompt})
# Apply chat template
formatted_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Tokenize
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
# Generate
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
# Decode
generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        # Parse thinking vs final content for thinking models.
        # Qwen3 thinking models wrap their reasoning in <think>...</think>; with
        # the default chat template the opening <think> tag can already be part
        # of the prompt, so the decoded output may contain only </think>.
        thinking_content = ""
        final_content = generated_text
        if model_info["supports_thinking"] and "</think>" in generated_text:
            before, _, after = generated_text.partition("</think>")
            thinking_content = before.replace("<think>", "", 1).strip()
            final_content = after.strip()
return PromptResponse(
thinking_content=thinking_content,
content=final_content,
model_used=model_to_use,
supports_thinking=model_info["supports_thinking"]
)
except Exception as e:
print(f"Generation error: {e}")
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
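
# --- Example client usage (illustrative sketch, not executed by this app) ---
# Assumes the server is reachable at http://localhost:7860 (the port used above);
# adjust the base URL for your deployment. Uses the third-party `requests`
# package, which is not imported by this file.
#
# import requests
#
# BASE = "http://localhost:7860"
# requests.post(f"{BASE}/load-model",
#               json={"model_name": "Qwen/Qwen3-4B-Instruct-2507"}).raise_for_status()
# resp = requests.post(f"{BASE}/generate", json={
#     "prompt": "Explain edge computing in one sentence.",
#     "max_new_tokens": 128,
# })
# print(resp.json()["content"])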