"""API-only module.

The web chat UI has been moved to a separate app (see web_app.py). This file now
exposes only the JSON API endpoints so it can be containerised independently or
scaled separately from the frontend.
"""

import os
import json
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from datetime import datetime
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from starlette.concurrency import run_in_threadpool
import subprocess
import sys

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.environ.get("DATA_DIR", "/app/data")
# os.makedirs(DATA_DIR, exist_ok=True)
OUTPUT_CHUNKS_FILE = os.path.join(DATA_DIR, "output_chunks.jsonl")
RAG_CONFIG_FILE = os.path.join(DATA_DIR, "rag_prompt_config.jsonl")
FAISS_INDEX_FILE = os.path.join(DATA_DIR, "faiss_index.index")
EMBEDDINGS_FILE = os.path.join(DATA_DIR, "chunk_embeddings.npy")
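# Expected (illustrative) shapes of the data files above, inferred from how they are
# read further down. Despite the .jsonl extension, each file is parsed with json.load
# as a single JSON document; the actual files are (re)built by generate_rag_data.py
# when missing and may carry additional fields.
#
#   output_chunks.jsonl     -> [{"content": "chunk text", ...}, ...]
#   rag_prompt_config.jsonl -> [{"base_chunk": {"content": "..."},
#                                "system_prompt": {"content": "..."}}]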
# Pydantic models for API
class ChatResponse(BaseModel):
    response: str
    timestamp: datetime


class ChatRequest(BaseModel):
    query: str
    history: list[dict] | None = None  # optional conversation history


# Initialize Gemini API
# API key is configured during app startup (lifespan) to avoid import-time failures


# Lifespan function to handle startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup (no DB setup anymore)
    print("Starting RAG Chat API (stateless, no database)...")
    API_KEY = os.getenv("GEMINI_API_KEY")
    if not API_KEY:
        raise RuntimeError("Please set GEMINI_API_KEY environment variable")
    genai.configure(api_key=API_KEY)
    success, chunks_count = initialize_system()
    if success:
        print(f"✅ RAG system initialized with {chunks_count} chunks")
        print("API ready at: http://localhost:8000 (docs at /docs)")
    else:
        raise RuntimeError("System initialization failed")
    yield
    print("Shutting down RAG Chat API...")


# Initialize FastAPI app with lifespan
app = FastAPI(
    title="RAG Chat API",
    description="Stateless RAG chat API (no database persistence)",
    lifespan=lifespan,
)

# Enable CORS for local dev
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost",
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "http://localhost:8001",  # web UI container
        "http://127.0.0.1:8001",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables to store precomputed data
chunks_data = None
chunk_embeddings = None
faiss_index = None
base_chunk = None
system_prompt = None
model_embedding = None

# Removed database session dependency (stateless mode)


def load_chunks(json_file):
    """Load chunks from a JSON file."""
    try:
        with open(json_file, "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        raise FileNotFoundError(f"File {json_file} not found!")
    except json.JSONDecodeError:
        raise ValueError(f"Invalid JSON in {json_file}")


def compute_and_cache_embeddings(chunks):
    """Compute embeddings for all chunks and cache them."""
    global chunk_embeddings, faiss_index
    print("Computing embeddings for all chunks...")
    texts = [chunk["content"] for chunk in chunks]
    # Load or compute embeddings
    if os.path.exists(EMBEDDINGS_FILE):
        print("Loading cached embeddings...")
        chunk_embeddings = np.load(EMBEDDINGS_FILE)
        if chunk_embeddings.shape[0] != len(texts):
            print("Cached embedding count does not match chunk count. Recomputing...")
            chunk_embeddings = model_embedding.encode(
                texts, convert_to_numpy=True
            ).astype("float32")
            np.save(EMBEDDINGS_FILE, chunk_embeddings)
    else:
        print("Computing new embeddings (this may take a moment)...")
        chunk_embeddings = model_embedding.encode(texts, convert_to_numpy=True).astype(
            "float32"
        )
        np.save(EMBEDDINGS_FILE, chunk_embeddings)
        print("Embeddings cached for future use.")
    # Normalize embeddings (for cosine similarity with IndexFlatIP)
    faiss.normalize_L2(chunk_embeddings)
    # Create or load FAISS index
    embedding_dim = chunk_embeddings.shape[1]
    if os.path.exists(FAISS_INDEX_FILE):
        print("Loading cached FAISS index...")
        faiss_index = faiss.read_index(FAISS_INDEX_FILE)
        # Validate index matches embeddings
        if getattr(faiss_index, "ntotal", 0) != chunk_embeddings.shape[0]:
            print("FAISS index size does not match embeddings. Rebuilding index...")
            faiss_index = faiss.IndexFlatIP(embedding_dim)
            faiss_index.add(chunk_embeddings)
            faiss.write_index(faiss_index, FAISS_INDEX_FILE)
    else:
        print("Creating new FAISS index...")
        faiss_index = faiss.IndexFlatIP(embedding_dim)
        faiss_index.add(chunk_embeddings)
        faiss.write_index(faiss_index, FAISS_INDEX_FILE)
        print("FAISS index cached for future use.")


def retrieve_relevant_chunks(query, top_k=3):
    """Retrieve the most relevant chunks for the query using precomputed embeddings."""
    global chunks_data, faiss_index
    if faiss_index is None or chunks_data is None:
        raise RuntimeError("RAG index not initialized")
    # Encode query
    query_embedding = model_embedding.encode([query], convert_to_numpy=True).astype(
        "float32"
    )
    faiss.normalize_L2(query_embedding)
    top_k = min(top_k, len(chunks_data))
    # Search in precomputed index
    _, indices = faiss_index.search(query_embedding, top_k)
    return [chunks_data[i] for i in indices[0]]


def _format_history(history: list[dict] | None, max_turns: int = 6) -> str:
    """Format recent conversation history for inclusion in the prompt."""
    if not history:
        return ""
    recent = history[-max_turns:]
    lines = []
    for turn in recent:
        role = turn.get("role", "user")
        msg = (turn.get("message") or "").strip()
        if not msg:
            continue
        prefix = "User" if role == "user" else "Assistant"
        lines.append(f"{prefix}: {msg}")
    return "\n".join(lines)
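# Illustrative shape of the `history` list consumed above (assumed from the keys that
# _format_history reads; turns with an empty "message" are skipped):
#   [
#       {"role": "user", "message": "..."},
#       {"role": "assistant", "message": "..."},
#   ]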
def construct_prompt(base_chunk, system_prompt, query, history_text: str = ""):
    """Construct the full prompt with relevant context."""
    relevant_chunks = retrieve_relevant_chunks(query)
    context = "\n\n".join(chunk["content"] for chunk in relevant_chunks)
    full_prompt = (
        f"System prompt:\n{system_prompt['content']}\n\n"
        f"Context:\n{context}\n\n"
        f"{base_chunk['content']}\n\n"
    )
    if history_text:
        full_prompt += f"Recent conversation:\n{history_text}\n\n"
    full_prompt += f"Query:\n{query}"
    return full_prompt, context


def get_answer(prompt):
    """Get an answer from the Gemini API."""
    try:
        model = genai.GenerativeModel("gemini-2.5-flash")
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error getting response from Gemini: {e}")
        return None


def run_generate_rag_data():
    """Run the data generation script if available."""
    script_path = os.path.join(SCRIPT_DIR, "generate_rag_data.py")
    if not os.path.isfile(script_path):
        print("generate_rag_data.py not found; skipping automatic generation.")
        return
    print("Running generate_rag_data.py to build RAG data...")
    try:
        subprocess.run([sys.executable, script_path], cwd=SCRIPT_DIR, check=True)
        print("generate_rag_data.py completed.")
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"generate_rag_data.py failed (exit {e.returncode})") from e


def initialize_system():
    """Initialize the RAG system with precomputed embeddings (stateless)."""
    global chunks_data, base_chunk, system_prompt, model_embedding
    try:
        need_generation = (
            not os.path.exists(EMBEDDINGS_FILE)
            or not os.path.exists(OUTPUT_CHUNKS_FILE)
            or not os.path.exists(RAG_CONFIG_FILE)
        )
        if need_generation:
            print("RAG data or embeddings missing. Triggering data generation...")
            run_generate_rag_data()
        print("Loading embedding model...")
        model_embedding = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
        print("Loading chunks and configuration...")
        chunks_data = load_chunks(OUTPUT_CHUNKS_FILE)
        config = load_chunks(RAG_CONFIG_FILE)[0]
        base_chunk = config["base_chunk"]
        system_prompt = config["system_prompt"]
        print(f"Loaded {len(chunks_data)} chunks from knowledge base")
        compute_and_cache_embeddings(chunks_data)
        print("System initialized successfully (stateless mode)")
        return True, len(chunks_data)
    except Exception as e:
        print(f"Failed to initialize system: {e}")
        return False, 0


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(payload: ChatRequest):
    """Chat endpoint that processes queries (no persistence)."""
    global base_chunk, system_prompt
    query = (payload.query or "").strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    history_text = _format_history(payload.history)
    full_prompt, _context = construct_prompt(
        base_chunk, system_prompt, query, history_text
    )
    answer = await run_in_threadpool(get_answer, full_prompt)
    if not answer:
        answer = "Sorry, I failed to get a response from Gemini. Please try again."
    return ChatResponse(response=answer, timestamp=datetime.utcnow())


# Simple health probe
@app.get("/health")
def health():
    return {"status": "ok"}


@app.get("/")
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")


if __name__ == "__main__":
    # Check environment variable before starting the server
    if not os.getenv("GEMINI_API_KEY"):
        print("Warning: GEMINI_API_KEY environment variable not set!")
        print("Please set it with: set GEMINI_API_KEY=your_api_key_here")
        sys.exit(1)
    print("Starting RAG Chat API server...")
    # Reload mode requires an import string; "app:app" assumes this module is saved as
    # app.py. Adjust the import string if the filename differs.
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True, log_level="info")
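# Example requests (illustrative; assumes the API is running locally on port 8000):
#
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What topics does the knowledge base cover?", "history": [{"role": "user", "message": "Hi"}]}'
#
#   curl http://localhost:8000/health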