# ============================================================
# app/core/llm_router.py - Smart LLM Routing with Fallbacks
# ============================================================
import os
import logging
import time
from typing import Optional, Dict, Any, List, Tuple
from enum import Enum

from openai import AsyncOpenAI
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

from app.core.observability import trace_operation, get_token_tracker

logger = logging.getLogger(__name__)


# ============================================================
# LLM Definitions
# ============================================================

class LLMModel(str, Enum):
    """Available LLM models"""

    # Primary (paid but reliable)
    DEEPSEEK_CHAT = "deepseek-chat"

    # Free OpenRouter fallbacks (ordered by preference)
    MISTRAL_FREE = "mistralai/mistral-7b-instruct"
    GROK_FREE = "xai-org/grok-beta"
    LLAMA_FREE = "meta-llama/llama-2-70b-chat"
    NEURAL_CHAT = "intel/neural-chat-7b"


# ============================================================
# Model Configuration
# ============================================================

MODEL_CONFIG = {
    LLMModel.DEEPSEEK_CHAT: {
        "api_key": os.getenv("DEEPSEEK_API_KEY"),
        "base_url": os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1"),
        "cost_per_1k_prompt": 0.00014,  # $0.14 per 1M tokens
        "cost_per_1k_completion": 0.00014,
        "max_tokens": 4096,
        "timeout": 30,
        "tier": "primary",
    },
    LLMModel.MISTRAL_FREE: {
        "api_key": os.getenv("OPENROUTER_API_KEY"),
        "base_url": "https://openrouter.ai/api/v1",
        "cost_per_1k_prompt": 0.0,  # Free
        "cost_per_1k_completion": 0.0,
        "max_tokens": 2048,
        "timeout": 60,
        "tier": "fallback_1",
        "headers": {
            "HTTP-Referer": "https://lojiz.com",
            "X-Title": "Lojiz",
        },
    },
    LLMModel.GROK_FREE: {
        "api_key": os.getenv("OPENROUTER_API_KEY"),
        "base_url": "https://openrouter.ai/api/v1",
        "cost_per_1k_prompt": 0.0,
        "cost_per_1k_completion": 0.0,
        "max_tokens": 2048,
        "timeout": 60,
        "tier": "fallback_2",
        "headers": {
            "HTTP-Referer": "https://lojiz.com",
            "X-Title": "Lojiz",
        },
    },
    LLMModel.LLAMA_FREE: {
        "api_key": os.getenv("OPENROUTER_API_KEY"),
        "base_url": "https://openrouter.ai/api/v1",
        "cost_per_1k_prompt": 0.0,
        "cost_per_1k_completion": 0.0,
        "max_tokens": 2048,
        "timeout": 60,
        "tier": "fallback_3",
        "headers": {
            "HTTP-Referer": "https://lojiz.com",
            "X-Title": "Lojiz",
        },
    },
    # NOTE: LLMModel.NEURAL_CHAT has no entry here, so _init_clients() logs a
    # warning for it and leaves it disabled.
}


# ============================================================
# Task Complexity Classification
# ============================================================

class TaskComplexity(str, Enum):
    """Task complexity levels"""

    SIMPLE = "simple"    # Routing, yes/no, basic extraction
    MEDIUM = "medium"    # Field validation, basic reasoning
    COMPLEX = "complex"  # Multi-step reasoning, generation


def classify_task_complexity(prompt: str, intent: Optional[str] = None) -> TaskComplexity:
    """Classify task complexity from prompt and intent"""
    # Simple tasks
    if intent in ["publish", "discard", "edit"]:
        return TaskComplexity.SIMPLE
    if any(w in prompt.lower() for w in ["yes", "no", "confirm", "publish"]):
        if len(prompt) < 50:
            return TaskComplexity.SIMPLE

    # Complex tasks
    if intent in ["generate", "summarize"]:
        return TaskComplexity.COMPLEX
    if len(prompt) > 500 or "generate" in prompt.lower() or "describe" in prompt.lower():
        return TaskComplexity.COMPLEX

    # Medium by default
    return TaskComplexity.MEDIUM
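
# Illustrative outcomes of classify_task_complexity. The sample prompts below
# are made up for documentation only and are not used anywhere in the module:
#
#   classify_task_complexity("yes, publish it", intent="publish")    -> SIMPLE
#   classify_task_complexity("Is the listed price correct?")         -> MEDIUM
#   classify_task_complexity("Generate a product description ...")   -> COMPLEX
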
# ============================================================
# Model Selection Strategy
# ============================================================

class LLMRouter:
    """Route requests to appropriate LLM based on task complexity and availability"""

    def __init__(self):
        self.clients: Dict[LLMModel, AsyncOpenAI] = {}
        self.model_status: Dict[LLMModel, bool] = {}
        self.call_count: Dict[LLMModel, int] = {}
        self.error_count: Dict[LLMModel, int] = {}
        self.token_tracker = get_token_tracker()
        self._init_clients()

    def _init_clients(self):
        """Initialize all LLM clients"""
        for model in LLMModel:
            try:
                config = MODEL_CONFIG.get(model)
                if not config or not config.get("api_key"):
                    logger.warning(f"⚠️ No API key for {model}, skipping")
                    self.model_status[model] = False
                    continue

                client = AsyncOpenAI(
                    api_key=config["api_key"],
                    base_url=config["base_url"],
                    timeout=config["timeout"],
                    default_headers=config.get("headers"),
                )
                self.clients[model] = client
                self.model_status[model] = True
                self.call_count[model] = 0
                self.error_count[model] = 0
                logger.info(f"✅ Initialized {model}")
            except Exception as e:
                logger.warning(f"⚠️ Failed to initialize {model}: {e}")
                self.model_status[model] = False

    def get_model_priority(self, complexity: TaskComplexity) -> List[LLMModel]:
        """Get models in priority order for task complexity"""
        if complexity == TaskComplexity.SIMPLE:
            # For simple tasks, try free models first
            return [
                LLMModel.MISTRAL_FREE,
                LLMModel.GROK_FREE,
                LLMModel.DEEPSEEK_CHAT,  # Fallback to paid if free fails
            ]
        elif complexity == TaskComplexity.MEDIUM:
            # For medium tasks, balance cost and quality
            return [
                LLMModel.DEEPSEEK_CHAT,  # Primary
                LLMModel.MISTRAL_FREE,
                LLMModel.GROK_FREE,
            ]
        else:  # COMPLEX
            # For complex tasks, use best model
            return [
                LLMModel.DEEPSEEK_CHAT,  # Primary always
                LLMModel.LLAMA_FREE,     # Fallback
            ]

    def get_next_available_model(self, complexity: TaskComplexity) -> Optional[LLMModel]:
        """Get next available model by priority"""
        priority = self.get_model_priority(complexity)
        for model in priority:
            if self.model_status.get(model, False):
                return model
        logger.error("❌ No available LLM models!")
        return None

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type((Exception,)),
    )
    async def call_llm(
        self,
        messages: List[Dict[str, str]],
        complexity: TaskComplexity = TaskComplexity.MEDIUM,
        temperature: float = 0.0,
        max_tokens: int = 600,
    ) -> Tuple[str, LLMModel, Dict[str, int]]:
        """
        Call LLM with automatic fallback

        Returns: (response_text, model_used, usage_stats)
        """
        priority = self.get_model_priority(complexity)
        last_error = None

        for model in priority:
            if not self.model_status.get(model, False):
                logger.debug(f"⭐️ Skipping {model} (unavailable)")
                continue

            try:
                with trace_operation(
                    f"llm_call.{model}",
                    {"complexity": complexity, "max_tokens": max_tokens},
                ):
                    logger.info(f"🤖 Calling {model} for {complexity} task")

                    client = self.clients[model]
                    config = MODEL_CONFIG[model]

                    # Make API call
                    start_time = time.time()
                    response = await client.chat.completions.create(
                        model=model.value,  # ✅ FIXED: Use enum value instead of str(model)
                        messages=messages,
                        temperature=temperature,
                        max_tokens=min(max_tokens, config["max_tokens"]),
                    )
                    duration = time.time() - start_time

                    # Extract response
                    text = response.choices[0].message.content
                    usage = {
                        "prompt_tokens": response.usage.prompt_tokens,
                        "completion_tokens": response.usage.completion_tokens,
                        "total_tokens": response.usage.total_tokens,
                        "duration_ms": int(duration * 1000),
                    }

                    # Calculate cost
                    cost = self._calculate_cost(model, usage)

                    # Track tokens
                    self.token_tracker.record_tokens(
                        str(model),
                        usage["prompt_tokens"],
                        usage["completion_tokens"],
                        cost,
                    )

                    # Update stats
                    self.call_count[model] = self.call_count.get(model, 0) + 1
                    self.error_count[model] = 0  # Reset errors on success

                    logger.info(
                        f"✅ {model} success | "
                        f"tokens={usage['total_tokens']} | "
                        f"cost=${cost:.4f} | "
                        f"duration={duration:.2f}s"
                    )

                    return text, model, usage

            except Exception as e:
                last_error = e
                self.error_count[model] = self.error_count.get(model, 0) + 1
                error_rate = self.error_count[model] / max(self.call_count.get(model, 1), 1)

                logger.warning(
                    f"⚠️ {model} failed (attempt): {str(e)[:100]} | "
                    f"error_rate={error_rate:.2f}"
                )

                # If error rate too high, mark model as unavailable
                if error_rate > 0.5:
                    logger.error(f"❌ {model} error rate too high, marking unavailable")
                    self.model_status[model] = False

                continue

        # All models failed
        logger.error(f"❌ All LLM models failed. Last error: {last_error}")
        raise RuntimeError(f"All LLM fallbacks exhausted: {last_error}")

    def _calculate_cost(self, model: LLMModel, usage: Dict[str, int]) -> float:
        """Calculate cost of LLM call"""
        config = MODEL_CONFIG[model]
        prompt_cost = (usage["prompt_tokens"] / 1000) * config["cost_per_1k_prompt"]
        completion_cost = (usage["completion_tokens"] / 1000) * config["cost_per_1k_completion"]
        return prompt_cost + completion_cost

    def get_stats(self) -> Dict[str, Any]:
        """Get router statistics"""
        return {
            "models": {
                str(model): {
                    "available": self.model_status.get(model, False),
                    "calls": self.call_count.get(model, 0),
                    "errors": self.error_count.get(model, 0),
                }
                for model in LLMModel
            },
            "total_calls": sum(self.call_count.values()),
            "total_errors": sum(self.error_count.values()),
        }
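
# Worked cost example for _calculate_cost, using the deepseek-chat figures from
# MODEL_CONFIG above (the token counts are made up for illustration):
#   1,000 prompt tokens + 500 completion tokens
#   = (1000 / 1000) * 0.00014 + (500 / 1000) * 0.00014
#   = $0.00021
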
# ============================================================
# Global Router Instance
# ============================================================

_router: Optional[LLMRouter] = None


def get_llm_router() -> LLMRouter:
    """Get or create global LLM router"""
    global _router
    if _router is None:
        _router = LLMRouter()
    return _router


async def call_llm_smart(
    messages: List[Dict[str, str]],
    intent: Optional[str] = None,
    temperature: float = 0.0,
    max_tokens: int = 600,
) -> Tuple[str, LLMModel, Dict[str, int]]:
    """Smart LLM call with complexity classification"""
    prompt = messages[-1]["content"] if messages else ""
    complexity = classify_task_complexity(prompt, intent)

    router = get_llm_router()
    return await router.call_llm(
        messages,
        complexity=complexity,
        temperature=temperature,
        max_tokens=max_tokens,
    )
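
# ------------------------------------------------------------
# Usage sketch (illustrative only; the message content, intent value, and
# asyncio wiring below are assumptions, not part of the application wiring):
#
#   import asyncio
#
#   async def _demo():
#       text, model, usage = await call_llm_smart(
#           messages=[{"role": "user", "content": "yes, publish the listing"}],
#           intent="publish",
#       )
#       print(f"{model}: {text} ({usage['total_tokens']} tokens)")
#
#   asyncio.run(_demo())
# ------------------------------------------------------------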