# ============================================================
# app/core/context_manager.py - Context Window Management
# ============================================================
import logging
import json
from typing import List, Dict, Optional, Tuple
from datetime import datetime, timedelta

import tiktoken

from app.core.error_handling import LojizError

logger = logging.getLogger(__name__)


# ============================================================
# Token Counter
# ============================================================

class TokenCounter:
    """Count tokens using tiktoken (OpenAI's tokenizer)"""

    def __init__(self, encoding_name: str = "cl100k_base"):
        try:
            self.encoding = tiktoken.get_encoding(encoding_name)
        except Exception as e:
            logger.warning(f"⚠️ Failed to load tiktoken: {e}, using fallback")
            self.encoding = None

    def count_tokens(self, text: str) -> int:
        """Count tokens in text"""
        if not self.encoding:
            # Fallback: rough estimate (4 chars ≈ 1 token)
            return len(text) // 4
        return len(self.encoding.encode(text))

    def count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
        """Count tokens in message list"""
        total = 0
        for msg in messages:
            # Add overhead per message (role + content markers)
            total += 4
            if msg.get("role"):
                total += self.count_tokens(msg["role"])
            if msg.get("content"):
                total += self.count_tokens(msg["content"])
            # Add overhead for message framing
            total += 2
        return total
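
# Illustrative usage sketch (exact counts depend on the tiktoken encoding
# in use; the figures below are only indicative):
#
#   counter = TokenCounter()
#   counter.count_tokens("Hello, world!")        # -> e.g. 4 tokens
#   counter.count_messages_tokens([
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hi!"},
#   ])                                           # content tokens + per-message overhead
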
# ============================================================
# Context Manager
# ============================================================

class ContextManager:
    """Manage context window to prevent overflow"""

    # Model limits (tokens)
    MODEL_LIMITS = {
        "deepseek-chat": 4096,
        "mistralai/mistral-7b-instruct": 8192,
        "xai-org/grok-beta": 8192,
        "meta-llama/llama-2-70b-chat": 4096,
    }

    # Reserve space for response
    RESPONSE_RESERVE = 600

    def __init__(self, model: str = "deepseek-chat"):
        self.model = model
        self.token_counter = TokenCounter()
        self.context_limit = self.MODEL_LIMITS.get(model, 4096)
        self.usable_limit = self.context_limit - self.RESPONSE_RESERVE

    def get_available_context(self, current_tokens: int) -> int:
        """Get available context space"""
        return max(0, self.usable_limit - current_tokens)

    def is_context_full(self, messages: List[Dict[str, str]]) -> bool:
        """Check if context is full"""
        tokens = self.token_counter.count_messages_tokens(messages)
        return tokens >= self.usable_limit

    async def manage_context(
        self,
        messages: List[Dict[str, str]],
        max_history_messages: int = 20,
    ) -> List[Dict[str, str]]:
        """
        Manage context by trimming history if needed

        Strategy:
        1. Keep the system message
        2. Keep the last user message (current input)
        3. Trim older history to the most recent max_history_messages
           (use summarize_conversation to condense what was dropped)
        """
        if not messages:
            return messages

        tokens = self.token_counter.count_messages_tokens(messages)

        if tokens <= self.usable_limit:
            logger.debug(
                f"✅ Context OK: {tokens}/{self.usable_limit} tokens, "
                f"{len(messages)} messages"
            )
            return messages

        logger.warning(
            f"⚠️ Context overflow: {tokens}/{self.usable_limit} tokens, "
            f"{len(messages)} messages"
        )

        # Keep system message + last user message, trim the rest
        system_msg = [m for m in messages if m.get("role") == "system"]
        user_msg = [m for m in messages if m.get("role") == "user"][-1:]
        history = [
            m for m in messages
            if m.get("role") != "system" and m not in user_msg
        ]

        # Trim history to most recent max_history_messages
        if len(history) > max_history_messages:
            logger.info(f"📦 Trimming history from {len(history)} to {max_history_messages}")
            history = history[-max_history_messages:]

        # Rebuild messages
        managed_messages = system_msg + history + user_msg

        final_tokens = self.token_counter.count_messages_tokens(managed_messages)
        logger.info(
            f"📦 Context managed: {final_tokens}/{self.usable_limit} tokens, "
            f"{len(managed_messages)} messages"
        )

        return managed_messages

    async def summarize_conversation(
        self,
        messages: List[Dict[str, str]],
        summarizer_fn=None,
    ) -> str:
        """
        Summarize conversation history

        Args:
            messages: Message history
            summarizer_fn: Optional async function to summarize

        Returns:
            Summary of conversation
        """
        if not messages or len(messages) < 3:
            return ""

        # Extract conversation content (skip system message)
        conversation = [m for m in messages if m.get("role") != "system"]
        conversation_text = "\n".join([
            f"{m.get('role', 'unknown').upper()}: {m.get('content', '')[:200]}"
            for m in conversation
        ])

        # If no custom summarizer, use basic extraction
        if not summarizer_fn:
            return self._basic_summary(conversation)

        # Use custom summarizer
        try:
            summary = await summarizer_fn(conversation_text)
            return summary
        except Exception as e:
            logger.warning(f"⚠️ Summarization failed: {e}, using basic summary")
            return self._basic_summary(conversation)

    def _basic_summary(self, messages: List[Dict[str, str]]) -> str:
        """Basic summary extraction"""
        summaries = []
        for msg in messages[-10:]:  # Last 10 messages
            content = msg.get("content", "")
            if len(content) > 100:
                # Extract key points
                lines = content.split("\n")
                key_lines = [l for l in lines if len(l) > 20][:2]
                summaries.append(" ".join(key_lines))
            else:
                summaries.append(content)
        return " | ".join(summaries)


# ============================================================
# Message Window (sliding window)
# ============================================================

class MessageWindow:
    """Sliding window for conversation history"""

    def __init__(self, window_size: int = 20, max_age_minutes: int = 120):
        self.window_size = window_size
        self.max_age = timedelta(minutes=max_age_minutes)
        self.messages: List[Dict[str, str]] = []
        self.created_at = datetime.utcnow()

    def add_message(self, role: str, content: str) -> None:
        """Add message to window"""
        msg = {
            "role": role,
            "content": content,
            "timestamp": datetime.utcnow().isoformat(),
        }
        self.messages.append(msg)

        # Maintain window size
        if len(self.messages) > self.window_size:
            self.messages.pop(0)
            logger.debug("📤 Removed old message from window")

    def get_messages(self, include_timestamps: bool = False) -> List[Dict[str, str]]:
        """Get messages in window"""
        messages = self.messages
        if not include_timestamps:
            # Remove timestamps for API calls
            messages = [
                {k: v for k, v in m.items() if k != "timestamp"}
                for m in messages
            ]
        return messages

    def is_expired(self) -> bool:
        """Check if window has expired"""
        return datetime.utcnow() - self.created_at > self.max_age

    def clear(self) -> None:
        """Clear window"""
        self.messages = []
        self.created_at = datetime.utcnow()

    def get_stats(self) -> Dict[str, int]:
        """Get window statistics"""
        return {
            "message_count": len(self.messages),
            "max_size": self.window_size,
            "age_seconds": int((datetime.utcnow() - self.created_at).total_seconds()),
        }


# ============================================================
# Global Context Manager
# ============================================================

_context_managers = {}
_message_windows = {}


def get_context_manager(model: str = "deepseek-chat") -> ContextManager:
    """Get or create context manager"""
    if model not in _context_managers:
        _context_managers[model] = ContextManager(model)
    return _context_managers[model]


def get_message_window(user_id: str, create_if_missing: bool = True) -> Optional[MessageWindow]:
    """Get or create message window for user"""
    if user_id not in _message_windows:
        if create_if_missing:
            _message_windows[user_id] = MessageWindow()
        else:
            return None

    window = _message_windows[user_id]

    # Check if expired
    if window.is_expired():
        logger.info(f"🗑️ Clearing expired window for user {user_id}")
        window.clear()

    return window


def cleanup_expired_windows() -> int:
    """Clean up expired message windows"""
    expired = [
        user_id for user_id, window in _message_windows.items()
        if window.is_expired()
    ]
    for user_id in expired:
        del _message_windows[user_id]

    if expired:
        logger.info(f"🧹 Cleaned up {len(expired)} expired windows")

    return len(expired)
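

# ============================================================
# Manual smoke test
# ============================================================
# Minimal usage sketch showing how the pieces above fit together.
# The user id "demo-user" and the sample messages are illustrative
# assumptions for demonstration; real callers supply their own history.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        manager = get_context_manager("deepseek-chat")
        window = get_message_window("demo-user")

        # Simulate a short conversation
        window.add_message("user", "Hello!")
        window.add_message("assistant", "Hi, how can I help?")
        window.add_message("user", "Summarize our chat so far.")

        messages = [{"role": "system", "content": "You are a helpful assistant."}]
        messages += window.get_messages()

        managed = await manager.manage_context(messages)
        tokens = manager.token_counter.count_messages_tokens(managed)
        print(f"{len(managed)} messages, ~{tokens} tokens (limit {manager.usable_limit})")
        print("Window stats:", window.get_stats())

    asyncio.run(_demo())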