"""API only module.

The web chat UI has been moved to a separate app (see web_app.py). This file now
exposes only the JSON API endpoints so it can be containerised independently or
scaled separately from the frontend.
"""

import os
import json
import subprocess
import sys
from contextlib import asynccontextmanager
from datetime import datetime

import google.generativeai as genai
import faiss
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from starlette.concurrency import run_in_threadpool

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.environ.get("DATA_DIR", "/app/data")
# os.makedirs(DATA_DIR, exist_ok=True)

OUTPUT_CHUNKS_FILE = os.path.join(DATA_DIR, "output_chunks.jsonl")
RAG_CONFIG_FILE = os.path.join(DATA_DIR, "rag_prompt_config.jsonl")
FAISS_INDEX_FILE = os.path.join(DATA_DIR, "faiss_index.index")
EMBEDDINGS_FILE = os.path.join(DATA_DIR, "chunk_embeddings.npy")
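# Illustrative note (inferred from how the loaders below consume these files, not
# from a spec): OUTPUT_CHUNKS_FILE is expected to hold a JSON array of chunk
# objects with a "content" field, and RAG_CONFIG_FILE a JSON array whose first
# element carries "base_chunk" and "system_prompt" entries, each with its own
# "content" field, e.g.:
#
#   output_chunks.jsonl:     [{"content": "First chunk text..."}, {"content": "..."}]
#   rag_prompt_config.jsonl: [{"base_chunk": {"content": "..."},
#                              "system_prompt": {"content": "..."}}]
#
# Any additional fields are ignored by this module.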
# Pydantic models for API
class ChatResponse(BaseModel):
    response: str
    timestamp: datetime


class ChatRequest(BaseModel):
    query: str
    history: list[dict] | None = None  # optional conversation history
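# Illustrative request/response bodies for the models above (values are made up;
# the history keys match what _format_history reads further down):
#
#   request:  {"query": "What does the service do?",
#              "history": [{"role": "user", "message": "Hi"},
#                          {"role": "assistant", "message": "Hello!"}]}
#   response: {"response": "...", "timestamp": "2024-01-01T00:00:00"}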
# The Gemini API key is configured during app startup (lifespan) to avoid
# import-time failures.


# Lifespan handler for startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup (no database setup anymore)
    print("Starting RAG Chat API (stateless, no database)...")
    API_KEY = os.getenv("GEMINI_API_KEY")
    if not API_KEY:
        raise RuntimeError("Please set GEMINI_API_KEY environment variable")
    genai.configure(api_key=API_KEY)

    success, chunks_count = initialize_system()
    if success:
        print(f"✅ RAG system initialized with {chunks_count} chunks")
        print("API ready at: http://localhost:8000 (docs at /docs)")
    else:
        raise RuntimeError("System initialization failed")

    yield

    # Shutdown
    print("Shutting down RAG Chat API...")
# Initialize FastAPI app with lifespan
app = FastAPI(
    title="RAG Chat API",
    description="Stateless RAG chat system exposing only the JSON API endpoints",
    lifespan=lifespan,
)
# Enable CORS for local dev
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost",
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "http://localhost:8001",  # web UI container
        "http://127.0.0.1:8001",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables to store precomputed data
chunks_data = None
chunk_embeddings = None
faiss_index = None
base_chunk = None
system_prompt = None
model_embedding = None

# Removed database session dependency (stateless mode)


def load_chunks(json_file):
    """Load chunks from a JSON file."""
    try:
        with open(json_file, "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        raise FileNotFoundError(f"File {json_file} not found!")
    except json.JSONDecodeError:
        raise ValueError(f"Invalid JSON in {json_file}")
def compute_and_cache_embeddings(chunks):
    """Compute embeddings for all chunks and cache them on disk."""
    global chunk_embeddings, faiss_index

    print("Computing embeddings for all chunks...")
    texts = [chunk["content"] for chunk in chunks]

    # Load or compute embeddings
    if os.path.exists(EMBEDDINGS_FILE):
        print("Loading cached embeddings...")
        chunk_embeddings = np.load(EMBEDDINGS_FILE)
        if chunk_embeddings.shape[0] != len(texts):
            print("Cached embeddings count mismatches chunks. Recomputing...")
            chunk_embeddings = model_embedding.encode(
                texts, convert_to_numpy=True
            ).astype("float32")
            np.save(EMBEDDINGS_FILE, chunk_embeddings)
    else:
        print("Computing new embeddings (this may take a moment)...")
        chunk_embeddings = model_embedding.encode(texts, convert_to_numpy=True).astype(
            "float32"
        )
        np.save(EMBEDDINGS_FILE, chunk_embeddings)
        print("Embeddings cached for future use.")

    # Normalize embeddings (for cosine similarity with IndexFlatIP)
    faiss.normalize_L2(chunk_embeddings)

    # Create or load FAISS index
    embedding_dim = chunk_embeddings.shape[1]
    if os.path.exists(FAISS_INDEX_FILE):
        print("Loading cached FAISS index...")
        faiss_index = faiss.read_index(FAISS_INDEX_FILE)
        # Validate that the index matches the embeddings
        if getattr(faiss_index, "ntotal", 0) != chunk_embeddings.shape[0]:
            print("FAISS index size mismatches embeddings. Rebuilding index...")
            faiss_index = faiss.IndexFlatIP(embedding_dim)
            faiss_index.add(chunk_embeddings)
            faiss.write_index(faiss_index, FAISS_INDEX_FILE)
    else:
        print("Creating new FAISS index...")
        faiss_index = faiss.IndexFlatIP(embedding_dim)
        faiss_index.add(chunk_embeddings)
        faiss.write_index(faiss_index, FAISS_INDEX_FILE)
        print("FAISS index cached for future use.")
def retrieve_relevant_chunks(query, top_k=3):
    """Retrieve the most relevant chunks for the query using the precomputed index."""
    global chunks_data, faiss_index
    if faiss_index is None or chunks_data is None:
        raise RuntimeError("RAG index not initialized")

    # Encode and normalize the query the same way as the chunks
    query_embedding = model_embedding.encode([query], convert_to_numpy=True).astype(
        "float32"
    )
    faiss.normalize_L2(query_embedding)

    top_k = min(top_k, len(chunks_data))
    # Search in the precomputed index
    _, indices = faiss_index.search(query_embedding, top_k)
    return [chunks_data[i] for i in indices[0]]
def _format_history(history: list[dict] | None, max_turns: int = 6) -> str:
    """Format recent conversation history for inclusion in the prompt."""
    if not history:
        return ""
    recent = history[-max_turns:]
    lines = []
    for turn in recent:
        role = turn.get("role", "user")
        msg = (turn.get("message") or "").strip()
        if not msg:
            continue
        prefix = "User" if role == "user" else "Assistant"
        lines.append(f"{prefix}: {msg}")
    return "\n".join(lines)
def construct_prompt(base_chunk, system_prompt, query, history_text: str = ""):
    """Construct the full prompt with the relevant retrieved context."""
    relevant_chunks = retrieve_relevant_chunks(query)
    context = "\n\n".join(chunk["content"] for chunk in relevant_chunks)

    full_prompt = (
        f"System prompt:\n{system_prompt['content']}\n\n"
        f"Context:\n{context}\n\n"
        f"{base_chunk['content']}\n\n"
    )
    if history_text:
        full_prompt += f"Recent conversation:\n{history_text}\n\n"
    full_prompt += f"Query:\n{query}"
    return full_prompt, context
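# For reference, the assembled prompt has this shape (sections in order; the
# "Recent conversation" block appears only when history is present):
#
#   System prompt:
#   <system_prompt["content"]>
#
#   Context:
#   <top-k retrieved chunks, separated by blank lines>
#
#   <base_chunk["content"]>
#
#   Recent conversation:
#   User: ...
#   Assistant: ...
#
#   Query:
#   <user query>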
def get_answer(prompt):
    """Get an answer from the Gemini API."""
    try:
        model = genai.GenerativeModel("gemini-2.5-flash")
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error getting response from Gemini: {e}")
        return None
def run_generate_rag_data():
    """Run the data generation script if it is available."""
    script_path = os.path.join(SCRIPT_DIR, "generate_rag_data.py")
    if not os.path.isfile(script_path):
        print("generate_rag_data.py not found; skipping automatic generation.")
        return
    print("Running generate_rag_data.py to build RAG data...")
    try:
        subprocess.run([sys.executable, script_path], cwd=SCRIPT_DIR, check=True)
        print("generate_rag_data.py completed.")
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"generate_rag_data.py failed (exit {e.returncode})") from e
def initialize_system():
    """Initialize the RAG system with precomputed embeddings (stateless)."""
    global chunks_data, base_chunk, system_prompt, model_embedding
    try:
        need_generation = (
            not os.path.exists(EMBEDDINGS_FILE)
            or not os.path.exists(OUTPUT_CHUNKS_FILE)
            or not os.path.exists(RAG_CONFIG_FILE)
        )
        if need_generation:
            print("RAG data or embeddings missing. Triggering data generation...")
            run_generate_rag_data()

        print("Loading embedding model...")
        model_embedding = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")

        print("Loading chunks and configuration...")
        chunks_data = load_chunks(OUTPUT_CHUNKS_FILE)
        config = load_chunks(RAG_CONFIG_FILE)[0]
        base_chunk = config["base_chunk"]
        system_prompt = config["system_prompt"]
        print(f"Loaded {len(chunks_data)} chunks from knowledge base")

        compute_and_cache_embeddings(chunks_data)

        print("System initialized successfully (stateless mode)")
        return True, len(chunks_data)
    except Exception as e:
        print(f"Failed to initialize system: {e}")
        return False, 0
@app.post("/chat", response_model=ChatResponse)  # route path assumed; keep in sync with the web UI client
async def chat_endpoint(payload: ChatRequest):
    """Chat endpoint that processes queries (no persistence)."""
    global base_chunk, system_prompt

    query = (payload.query or "").strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    history_text = _format_history(payload.history)
    full_prompt, _context = construct_prompt(
        base_chunk, system_prompt, query, history_text
    )

    # Run the blocking Gemini call off the event loop
    answer = await run_in_threadpool(get_answer, full_prompt)
    if not answer:
        answer = "Sorry, I failed to get a response from Gemini. Please try again."

    return ChatResponse(response=answer, timestamp=datetime.utcnow())
# Simple health probe
@app.get("/health")  # route path assumed
def health():
    return {"status": "ok"}


@app.get("/")  # root redirects to the interactive docs
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")
if __name__ == "__main__":
    # Check environment variable before starting the server
    if not os.getenv("GEMINI_API_KEY"):
        print("Warning: GEMINI_API_KEY environment variable not set!")
        print("Please set it first, e.g. export GEMINI_API_KEY=your_api_key_here "
              "(or set GEMINI_API_KEY=... on Windows)")
        sys.exit(1)

    print("Starting RAG Chat API server...")
    # Pass the app object directly; auto-reload requires an import string such as
    # "module_name:app" instead of the object.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
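# Quick smoke test once the server is running (assumes the routes registered
# above and the default port 8000):
#
#   curl http://localhost:8000/health
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What does this service do?"}'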