"""API only module.

The web chat UI has been moved to a separate app (see web_app.py). This file now
exposes only the JSON API endpoints so it can be containerised independently or
scaled separately from the frontend.
"""

import os
import json
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from datetime import datetime
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from starlette.concurrency import run_in_threadpool
import subprocess
import sys
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.environ.get("DATA_DIR", "/app/data")
# os.makedirs(DATA_DIR, exist_ok=True)
OUTPUT_CHUNKS_FILE = os.path.join(DATA_DIR, "output_chunks.jsonl")
RAG_CONFIG_FILE = os.path.join(DATA_DIR, "rag_prompt_config.jsonl")
FAISS_INDEX_FILE = os.path.join(DATA_DIR, "faiss_index.index")
EMBEDDINGS_FILE = os.path.join(DATA_DIR, "chunk_embeddings.npy")
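# These data files are produced by generate_rag_data.py; if any of them are
# missing at startup, initialize_system() below triggers that script
# automatically. DATA_DIR defaults to /app/data (the container layout) and can
# be overridden via the DATA_DIR environment variable for local runs.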
# Pydantic models for API
class ChatResponse(BaseModel):
    response: str
    timestamp: datetime


class ChatRequest(BaseModel):
    query: str
    history: list[dict] | None = None  # optional conversation history
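# Example /chat request body (shape follows ChatRequest above; values are
# purely illustrative):
#   {
#     "query": "What topics does the knowledge base cover?",
#     "history": [
#       {"role": "user", "message": "Hi"},
#       {"role": "assistant", "message": "Hello! How can I help?"}
#     ]
#   }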
# Initialize Gemini API
# API key is configured during app startup (lifespan) to avoid import-time failures
# Lifespan function to handle startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup (no DB setup anymore)
    print("Starting RAG Chat API (stateless, no database)...")
    API_KEY = os.getenv("GEMINI_API_KEY")
    if not API_KEY:
        raise RuntimeError("Please set GEMINI_API_KEY environment variable")
    genai.configure(api_key=API_KEY)
    success, chunks_count = initialize_system()
    if success:
        print(f"✅ RAG system initialized with {chunks_count} chunks")
        print("API ready at: http://localhost:8000 (docs at /docs)")
    else:
        raise RuntimeError("System initialization failed")
    yield
    print("Shutting down RAG Chat API...")
# Initialize FastAPI app with lifespan
app = FastAPI(
    title="RAG Chat API",
    description="Stateless RAG Chat API (no database persistence)",
    lifespan=lifespan,
)
"""API only module.
The web chat UI has been moved to a separate app (see web_app.py). This file now
exposes only the JSON API endpoints so it can be containerised independently or
scaled separately from the frontend.
"""
# Enable CORS for local dev
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost",
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "http://localhost:8001",  # web UI container
        "http://127.0.0.1:8001",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables to store precomputed data
chunks_data = None
chunk_embeddings = None
faiss_index = None
base_chunk = None
system_prompt = None
model_embedding = None
# Removed database session dependency (stateless mode)
def load_chunks(json_file):
    """Load chunks from JSON file"""
    try:
        with open(json_file, "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        raise FileNotFoundError(f"File {json_file} not found!")
    except json.JSONDecodeError:
        raise ValueError(f"Invalid JSON in {json_file}")
def compute_and_cache_embeddings(chunks):
    """Compute embeddings for all chunks and cache them"""
    global chunk_embeddings, faiss_index
    print("Computing embeddings for all chunks...")
    texts = [chunk["content"] for chunk in chunks]
    # Load or compute embeddings
    if os.path.exists(EMBEDDINGS_FILE):
        print("Loading cached embeddings...")
        chunk_embeddings = np.load(EMBEDDINGS_FILE)
        if chunk_embeddings.shape[0] != len(texts):
            print("Cached embeddings count mismatches chunks. Recomputing...")
            chunk_embeddings = model_embedding.encode(
                texts, convert_to_numpy=True
            ).astype("float32")
            np.save(EMBEDDINGS_FILE, chunk_embeddings)
    else:
        print("Computing new embeddings (this may take a moment)...")
        chunk_embeddings = model_embedding.encode(texts, convert_to_numpy=True).astype(
            "float32"
        )
        np.save(EMBEDDINGS_FILE, chunk_embeddings)
        print("Embeddings cached for future use.")
    # Normalize embeddings (for cosine similarity with IndexFlatIP)
    faiss.normalize_L2(chunk_embeddings)
    # Create or load FAISS index
    embedding_dim = chunk_embeddings.shape[1]
    if os.path.exists(FAISS_INDEX_FILE):
        print("Loading cached FAISS index...")
        faiss_index = faiss.read_index(FAISS_INDEX_FILE)
        # Validate index matches embeddings
        if getattr(faiss_index, "ntotal", 0) != chunk_embeddings.shape[0]:
            print("FAISS index size mismatches embeddings. Rebuilding index...")
            faiss_index = faiss.IndexFlatIP(embedding_dim)
            faiss_index.add(chunk_embeddings)
            faiss.write_index(faiss_index, FAISS_INDEX_FILE)
    else:
        print("Creating new FAISS index...")
        faiss_index = faiss.IndexFlatIP(embedding_dim)
        faiss_index.add(chunk_embeddings)
        faiss.write_index(faiss_index, FAISS_INDEX_FILE)
        print("FAISS index cached for future use.")
def retrieve_relevant_chunks(query, top_k=3):
    """Retrieve most relevant chunks for the query using precomputed embeddings"""
    global chunks_data, faiss_index
    if faiss_index is None or chunks_data is None:
        raise RuntimeError("RAG index not initialized")
    # Encode query
    query_embedding = model_embedding.encode([query], convert_to_numpy=True).astype(
        "float32"
    )
    faiss.normalize_L2(query_embedding)
    top_k = min(top_k, len(chunks_data))
    # Search in precomputed index
    _, indices = faiss_index.search(query_embedding, top_k)
    return [chunks_data[i] for i in indices[0]]
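# Example usage after initialize_system() has run (chunk contents are
# hypothetical):
#   retrieve_relevant_chunks("How do I configure the API key?", top_k=2)
#   -> [{"content": "..."}, {"content": "..."}]  # the 2 most similar chunks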
def _format_history(history: list[dict] | None, max_turns: int = 6) -> str:
    """Format recent conversation history for inclusion in the prompt."""
    if not history:
        return ""
    recent = history[-max_turns:]
    lines = []
    for turn in recent:
        role = turn.get("role", "user")
        msg = (turn.get("message") or "").strip()
        if not msg:
            continue
        prefix = "User" if role == "user" else "Assistant"
        lines.append(f"{prefix}: {msg}")
    return "\n".join(lines)
def construct_prompt(base_chunk, system_prompt, query, history_text: str = ""):
    """Construct the full prompt with relevant context"""
    relevant_chunks = retrieve_relevant_chunks(query)
    context = "\n\n".join(chunk["content"] for chunk in relevant_chunks)
    full_prompt = (
        f"System prompt:\n{system_prompt['content']}\n\n"
        f"Context:\n{context}\n\n"
        f"{base_chunk['content']}\n\n"
    )
    if history_text:
        full_prompt += f"Recent conversation:\n{history_text}\n\n"
    full_prompt += f"Query:\n{query}"
    return full_prompt, context
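# The assembled prompt is laid out in this order (labels come from the
# f-strings above):
#   System prompt: <system_prompt["content"]>
#   Context:       <top-k retrieved chunk contents, blank-line separated>
#   <base_chunk["content"]>
#   Recent conversation: <formatted history, only if provided>
#   Query:         <the user's query>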
def get_answer(prompt):
    """Get answer from Gemini API"""
    try:
        model = genai.GenerativeModel("gemini-2.5-flash")
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error getting response from Gemini: {e}")
        return None
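# Note: on any Gemini error this returns None; the /chat endpoint below
# substitutes a generic apology message in that case rather than raising.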
def run_generate_rag_data():
    """Run the data generation script if available."""
    script_path = os.path.join(SCRIPT_DIR, "generate_rag_data.py")
    if not os.path.isfile(script_path):
        print("generate_rag_data.py not found; skipping automatic generation.")
        return
    print("Running generate_rag_data.py to build RAG data...")
    try:
        subprocess.run([sys.executable, script_path], cwd=SCRIPT_DIR, check=True)
        print("generate_rag_data.py completed.")
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"generate_rag_data.py failed (exit {e.returncode})") from e
def initialize_system():
    """Initialize the RAG system with precomputed embeddings (stateless)."""
    global chunks_data, base_chunk, system_prompt, model_embedding
    try:
        need_generation = (
            not os.path.exists(EMBEDDINGS_FILE)
            or not os.path.exists(OUTPUT_CHUNKS_FILE)
            or not os.path.exists(RAG_CONFIG_FILE)
        )
        if need_generation:
            print("RAG data or embeddings missing. Triggering data generation...")
            run_generate_rag_data()
        print("Loading embedding model...")
        model_embedding = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
        print("Loading chunks and configuration...")
        chunks_data = load_chunks(OUTPUT_CHUNKS_FILE)
        config = load_chunks(RAG_CONFIG_FILE)[0]
        base_chunk = config["base_chunk"]
        system_prompt = config["system_prompt"]
        print(f"Loaded {len(chunks_data)} chunks from knowledge base")
        compute_and_cache_embeddings(chunks_data)
        print("System initialized successfully (stateless mode)")
        return True, len(chunks_data)
    except Exception as e:
        print(f"Failed to initialize system: {e}")
        return False, 0
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(payload: ChatRequest):
"""Chat endpoint that processes queries (no persistence)."""
global base_chunk, system_prompt
query = (payload.query or "").strip()
if not query:
raise HTTPException(status_code=400, detail="Query cannot be empty")
history_text = _format_history(payload.history)
full_prompt, _context = construct_prompt(
base_chunk, system_prompt, query, history_text
)
answer = await run_in_threadpool(get_answer, full_prompt)
if not answer:
answer = "Sorry, I failed to get a response from Gemini. Please try again."
return ChatResponse(response=answer, timestamp=datetime.utcnow())
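# Example call once the server is running (illustrative query):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"query": "What does the knowledge base cover?"}'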
# Simple health probe
@app.get("/health")
def health():
    return {"status": "ok"}


@app.get("/")
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")
if __name__ == "__main__":
    # Check environment variable
    if not os.getenv("GEMINI_API_KEY"):
        print("Warning: GEMINI_API_KEY environment variable not set!")
        print("Set it first, e.g.: export GEMINI_API_KEY=your_api_key_here ('set' on Windows)")
        exit(1)
    print("Starting RAG Chat API server...")
    # reload=True needs an import string in "module:attribute" form;
    # "app:app" assumes this file is saved as app.py
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True, log_level="info")