from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Optional, Dict, Any
import os

app = FastAPI(title="Edge LLM API")

# Enable CORS for Hugging Face Space
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for HF Space
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files
app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")

# Available models
AVAILABLE_MODELS = {
    "Qwen/Qwen3-4B-Thinking-2507": {
        "name": "Qwen3-4B-Thinking-2507",
        "supports_thinking": True,
        "description": "Shows thinking process",
        "size_gb": "~8GB",
    },
    "Qwen/Qwen3-4B-Instruct-2507": {
        "name": "Qwen3-4B-Instruct-2507",
        "supports_thinking": False,
        "description": "Direct instruction following",
        "size_gb": "~8GB",
    },
}

# Global model cache
models_cache: Dict[str, Dict[str, Any]] = {}
current_model_name = None  # No model loaded by default


class PromptRequest(BaseModel):
    prompt: str
    system_prompt: Optional[str] = None
    model_name: Optional[str] = None
    temperature: Optional[float] = 0.7
    max_new_tokens: Optional[int] = 1024


class PromptResponse(BaseModel):
    thinking_content: str
    content: str
    model_used: str
    supports_thinking: bool


class ModelInfo(BaseModel):
    model_name: str
    name: str
    supports_thinking: bool
    description: str
    size_gb: str
    is_loaded: bool


class ModelsResponse(BaseModel):
    models: list[ModelInfo]
    current_model: str


class ModelLoadRequest(BaseModel):
    model_name: str


class ModelUnloadRequest(BaseModel):
    model_name: str


def load_model_by_name(model_name: str):
    """Load a model into the cache"""
    global models_cache

    if model_name in models_cache:
        return True

    if model_name not in AVAILABLE_MODELS:
        return False

    try:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        models_cache[model_name] = {
            "model": model,
            "tokenizer": tokenizer
        }
        print(f"Model {model_name} loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return False


def unload_model_by_name(model_name: str):
    """Unload a model from the cache"""
    global models_cache, current_model_name

    if model_name in models_cache:
        del models_cache[model_name]
        if current_model_name == model_name:
            current_model_name = None
        print(f"Model {model_name} unloaded")
        return True
    return False


@app.on_event("startup")
async def startup_event():
    """Startup event - don't load models by default"""
    print("🚀 Edge LLM API is starting up...")
    print("💡 Models will be loaded on demand")


@app.get("/")
async def read_index():
    """Serve the React app"""
    return FileResponse('static/index.html')


@app.get("/health")
async def health_check():
    return {"status": "healthy", "message": "Edge LLM API is running"}


@app.get("/models", response_model=ModelsResponse)
async def get_models():
    """Get available models and their status"""
    global current_model_name

    models = []
    for model_name, info in AVAILABLE_MODELS.items():
        models.append(ModelInfo(
            model_name=model_name,
            name=info["name"],
            supports_thinking=info["supports_thinking"],
            description=info["description"],
            size_gb=info["size_gb"],
            is_loaded=model_name in models_cache
        ))

    return ModelsResponse(
        models=models,
        current_model=current_model_name or ""
    )
@app.post("/load-model") async def load_model(request: ModelLoadRequest): """Load a specific model""" global current_model_name if request.model_name not in AVAILABLE_MODELS: raise HTTPException( status_code=400, detail=f"Model {request.model_name} not available" ) success = load_model_by_name(request.model_name) if success: current_model_name = request.model_name return { "message": f"Model {request.model_name} loaded successfully", "current_model": current_model_name } else: raise HTTPException( status_code=500, detail=f"Failed to load model {request.model_name}" ) @app.post("/unload-model") async def unload_model(request: ModelUnloadRequest): """Unload a specific model""" global current_model_name success = unload_model_by_name(request.model_name) if success: return { "message": f"Model {request.model_name} unloaded successfully", "current_model": current_model_name or "" } else: raise HTTPException( status_code=404, detail=f"Model {request.model_name} not found in cache" ) @app.post("/set-current-model") async def set_current_model(request: ModelLoadRequest): """Set the current active model""" global current_model_name if request.model_name not in models_cache: raise HTTPException( status_code=400, detail=f"Model {request.model_name} is not loaded. Please load it first." ) current_model_name = request.model_name return { "message": f"Current model set to {current_model_name}", "current_model": current_model_name } @app.post("/generate", response_model=PromptResponse) async def generate_text(request: PromptRequest): """Generate text using the loaded model""" global current_model_name # Use the model specified in request, or fall back to current model model_to_use = request.model_name if request.model_name else current_model_name if not model_to_use: raise HTTPException( status_code=400, detail="No model specified. Please load a model first." ) if model_to_use not in models_cache: raise HTTPException( status_code=400, detail=f"Model {model_to_use} is not loaded. Please load it first." 
        )

    try:
        model = models_cache[model_to_use]["model"]
        tokenizer = models_cache[model_to_use]["tokenizer"]
        model_info = AVAILABLE_MODELS[model_to_use]

        # Build the prompt
        messages = []
        if request.system_prompt:
            messages.append({"role": "system", "content": request.system_prompt})
        messages.append({"role": "user", "content": request.prompt})

        # Apply chat template
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens (skip the prompt)
        generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Parse thinking vs final content for thinking models
        thinking_content = ""
        final_content = generated_text

        if model_info["supports_thinking"] and "<think>" in generated_text:
            parts = generated_text.split("<think>")
            if len(parts) > 1:
                thinking_part = parts[1]
                if "</think>" in thinking_part:
                    thinking_content = thinking_part.split("</think>")[0].strip()
                    remaining = thinking_part.split("</think>", 1)[1]
                    final_content = remaining.strip()

        return PromptResponse(
            thinking_content=thinking_content,
            content=final_content,
            model_used=model_to_use,
            supports_thinking=model_info["supports_thinking"]
        )

    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
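
# Example client usage -- a minimal sketch, assuming the server runs locally on
# port 7860 as configured above and that the `requests` package is installed.
# The endpoints, request fields, and response keys below come from the routes
# defined in this file; loading a model may take a while on first download (~8GB).
#
#   import requests
#
#   base = "http://localhost:7860"
#   requests.post(f"{base}/load-model",
#                 json={"model_name": "Qwen/Qwen3-4B-Thinking-2507"})
#   resp = requests.post(f"{base}/generate",
#                        json={"prompt": "Why is the sky blue?",
#                              "max_new_tokens": 256})
#   print(resp.json()["thinking_content"])
#   print(resp.json()["content"])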