import os
import json
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from datetime import datetime, timezone
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from starlette.concurrency import run_in_threadpool
import subprocess
import sys
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.environ.get("DATA_DIR", "/app/data")
# os.makedirs(DATA_DIR, exist_ok=True)
OUTPUT_CHUNKS_FILE = os.path.join(DATA_DIR, "output_chunks.jsonl")
RAG_CONFIG_FILE = os.path.join(DATA_DIR, "rag_prompt_config.jsonl")
FAISS_INDEX_FILE = os.path.join(DATA_DIR, "faiss_index.index")
EMBEDDINGS_FILE = os.path.join(DATA_DIR, "chunk_embeddings.npy")
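# Expected data layout (inferred from how these files are read below; illustrative only):
#   OUTPUT_CHUNKS_FILE - JSON array of chunk objects, each with a "content" string, e.g.
#       [{"content": "First chunk of source text..."}, {"content": "Second chunk..."}]
#   RAG_CONFIG_FILE    - JSON array whose first element carries the prompt config, e.g.
#       [{"base_chunk": {"content": "..."}, "system_prompt": {"content": "..."}}]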
# Pydantic models for API
class ChatResponse(BaseModel):
response: str
timestamp: datetime
class ChatRequest(BaseModel):
query: str
history: list[dict] | None = None # optional conversation history
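# Illustrative /chat request body (history items use "role"/"message" keys, as
# consumed by _format_history below; any role other than "user" is rendered as Assistant):
#   {
#     "query": "What does the knowledge base cover?",
#     "history": [
#       {"role": "user", "message": "Hi"},
#       {"role": "assistant", "message": "Hello! How can I help?"}
#     ]
#   }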
# Initialize Gemini API
# API key is configured during app startup (lifespan) to avoid import-time failures
# Lifespan function to handle startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup (no DB setup anymore)
print("Starting RAG Chat API (stateless, no database)...")
API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
raise RuntimeError("Please set GEMINI_API_KEY environment variable")
genai.configure(api_key=API_KEY)
success, chunks_count = initialize_system()
if success:
print(f"✅ RAG system initialized with {chunks_count} chunks")
print("API ready at: http://localhost:8000 (docs at /docs)")
else:
raise RuntimeError("System initialization failed")
yield
print("Shutting down RAG Chat API...")
# Initialize FastAPI app with lifespan
app = FastAPI(
title="RAG Chat API",
description="RAG Chat System with Database Integration",
lifespan=lifespan,
)
"""API only module.
The web chat UI has been moved to a separate app (see web_app.py). This file now
exposes only the JSON API endpoints so it can be containerised independently or
scaled separately from the frontend.
"""
# Enable CORS for local dev
app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost",
"http://localhost:3000",
"http://127.0.0.1:3000",
"http://localhost:8001", # web UI container
"http://127.0.0.1:8001",
],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global variables to store precomputed data
chunks_data = None
chunk_embeddings = None
faiss_index = None
base_chunk = None
system_prompt = None
model_embedding = None
# Removed database session dependency (stateless mode)
def load_chunks(json_file):
"""Load chunks from JSON file"""
try:
with open(json_file, "r", encoding="utf-8") as file:
return json.load(file)
except FileNotFoundError:
raise FileNotFoundError(f"File {json_file} not found!")
except json.JSONDecodeError:
raise ValueError(f"Invalid JSON in {json_file}")
def compute_and_cache_embeddings(chunks):
"""Compute embeddings for all chunks and cache them"""
global chunk_embeddings, faiss_index
print("Computing embeddings for all chunks...")
texts = [chunk["content"] for chunk in chunks]
# Load or compute embeddings
if os.path.exists(EMBEDDINGS_FILE):
print("Loading cached embeddings...")
chunk_embeddings = np.load(EMBEDDINGS_FILE)
if chunk_embeddings.shape[0] != len(texts):
print("Cached embeddings count mismatches chunks. Recomputing...")
chunk_embeddings = model_embedding.encode(
texts, convert_to_numpy=True
).astype("float32")
np.save(EMBEDDINGS_FILE, chunk_embeddings)
else:
print("Computing new embeddings (this may take a moment)...")
chunk_embeddings = model_embedding.encode(texts, convert_to_numpy=True).astype(
"float32"
)
np.save(EMBEDDINGS_FILE, chunk_embeddings)
print("Embeddings cached for future use.")
# Normalize embeddings (for cosine similarity with IndexFlatIP)
faiss.normalize_L2(chunk_embeddings)
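    # After L2 normalization, the inner product computed by IndexFlatIP equals
    # cosine similarity: <a/||a||, b/||b||> = cos(a, b).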
# Create or load FAISS index
embedding_dim = chunk_embeddings.shape[1]
if os.path.exists(FAISS_INDEX_FILE):
print("Loading cached FAISS index...")
faiss_index = faiss.read_index(FAISS_INDEX_FILE)
# Validate index matches embeddings
if getattr(faiss_index, "ntotal", 0) != chunk_embeddings.shape[0]:
print("FAISS index size mismatches embeddings. Rebuilding index...")
faiss_index = faiss.IndexFlatIP(embedding_dim)
faiss_index.add(chunk_embeddings)
faiss.write_index(faiss_index, FAISS_INDEX_FILE)
else:
print("Creating new FAISS index...")
faiss_index = faiss.IndexFlatIP(embedding_dim)
faiss_index.add(chunk_embeddings)
faiss.write_index(faiss_index, FAISS_INDEX_FILE)
print("FAISS index cached for future use.")
def retrieve_relevant_chunks(query, top_k=3):
"""Retrieve most relevant chunks for the query using precomputed embeddings"""
global chunks_data, faiss_index
if faiss_index is None or chunks_data is None:
raise RuntimeError("RAG index not initialized")
# Encode query
query_embedding = model_embedding.encode([query], convert_to_numpy=True).astype(
"float32"
)
faiss.normalize_L2(query_embedding)
top_k = min(top_k, len(chunks_data))
# Search in precomputed index
_, indices = faiss_index.search(query_embedding, top_k)
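    # search() returns (scores, indices), each shaped (1, top_k) for the single
    # query row; indices[0] holds the positions of the best-matching chunks.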
return [chunks_data[i] for i in indices[0]]
def _format_history(history: list[dict] | None, max_turns: int = 6) -> str:
"""Format recent conversation history for inclusion in the prompt."""
if not history:
return ""
recent = history[-max_turns:]
lines = []
for turn in recent:
role = turn.get("role", "user")
msg = (turn.get("message") or "").strip()
if not msg:
continue
prefix = "User" if role == "user" else "Assistant"
lines.append(f"{prefix}: {msg}")
return "\n".join(lines)
def construct_prompt(base_chunk, system_prompt, query, history_text: str = ""):
"""Construct the full prompt with relevant context"""
relevant_chunks = retrieve_relevant_chunks(query)
context = "\n\n".join(chunk["content"] for chunk in relevant_chunks)
full_prompt = (
f"System prompt:\n{system_prompt['content']}\n\n"
f"Context:\n{context}\n\n"
f"{base_chunk['content']}\n\n"
)
if history_text:
full_prompt += f"Recent conversation:\n{history_text}\n\n"
full_prompt += f"Query:\n{query}"
return full_prompt, context
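# Resulting prompt layout (illustrative):
#   System prompt:\n<system prompt content>\n\n
#   Context:\n<top-3 retrieved chunk contents, blank-line separated>\n\n
#   <base_chunk content>\n\n
#   Recent conversation:\n<formatted history>\n\n   (only when history is supplied)
#   Query:\n<user query>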
def get_answer(prompt):
"""Get answer from Gemini API"""
try:
model = genai.GenerativeModel("gemini-2.5-flash")
response = model.generate_content(prompt)
return response.text
except Exception as e:
print(f"Error getting response from Gemini: {e}")
return None
def run_generate_rag_data():
"""Run the data generation script if available."""
script_path = os.path.join(SCRIPT_DIR, "generate_rag_data.py")
if not os.path.isfile(script_path):
print("generate_rag_data.py not found; skipping automatic generation.")
return
print("Running generate_rag_data.py to build RAG data...")
try:
subprocess.run([sys.executable, script_path], cwd=SCRIPT_DIR, check=True)
print("generate_rag_data.py completed.")
except subprocess.CalledProcessError as e:
raise RuntimeError(f"generate_rag_data.py failed (exit {e.returncode})") from e
def initialize_system():
"""Initialize the RAG system with precomputed embeddings (stateless)."""
global chunks_data, base_chunk, system_prompt, model_embedding
try:
need_generation = (
not os.path.exists(EMBEDDINGS_FILE)
or not os.path.exists(OUTPUT_CHUNKS_FILE)
or not os.path.exists(RAG_CONFIG_FILE)
)
if need_generation:
print("RAG data or embeddings missing. Triggering data generation...")
run_generate_rag_data()
print("Loading embedding model...")
model_embedding = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
print("Loading chunks and configuration...")
chunks_data = load_chunks(OUTPUT_CHUNKS_FILE)
config = load_chunks(RAG_CONFIG_FILE)[0]
base_chunk = config["base_chunk"]
system_prompt = config["system_prompt"]
print(f"Loaded {len(chunks_data)} chunks from knowledge base")
compute_and_cache_embeddings(chunks_data)
print("System initialized successfully (stateless mode)")
return True, len(chunks_data)
except Exception as e:
print(f"Failed to initialize system: {e}")
return False, 0
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(payload: ChatRequest):
"""Chat endpoint that processes queries (no persistence)."""
global base_chunk, system_prompt
query = (payload.query or "").strip()
if not query:
raise HTTPException(status_code=400, detail="Query cannot be empty")
history_text = _format_history(payload.history)
full_prompt, _context = construct_prompt(
base_chunk, system_prompt, query, history_text
)
answer = await run_in_threadpool(get_answer, full_prompt)
if not answer:
answer = "Sorry, I failed to get a response from Gemini. Please try again."
    return ChatResponse(response=answer, timestamp=datetime.now(timezone.utc))
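# Example request (assumes the server is running locally on port 8000; the query
# string is purely illustrative):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What topics does the knowledge base cover?"}'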
# Simple health probe
@app.get("/health")
def health():
return {"status": "ok"}
@app.get("/")
async def redirect_to_docs():
from fastapi.responses import RedirectResponse
return RedirectResponse(url="/docs")
if __name__ == "__main__":
    # Fail fast if the API key is missing (the lifespan hook would also raise at startup)
    if not os.getenv("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable not set!")
        print("Set it with: export GEMINI_API_KEY=your_api_key_here (or `set ...` on Windows)")
        sys.exit(1)
    print("Starting RAG Chat API server...")
    # reload=True requires an import string in "module:attribute" form
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, log_level="info")