import os
import json
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from datetime import datetime, timezone
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from starlette.concurrency import run_in_threadpool
import subprocess
import sys


SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

DATA_DIR = os.environ.get("DATA_DIR", "/app/data")
# os.makedirs(DATA_DIR, exist_ok=True)

OUTPUT_CHUNKS_FILE = os.path.join(DATA_DIR, "output_chunks.jsonl")
RAG_CONFIG_FILE = os.path.join(DATA_DIR, "rag_prompt_config.jsonl")
FAISS_INDEX_FILE = os.path.join(DATA_DIR, "faiss_index.index")
EMBEDDINGS_FILE = os.path.join(DATA_DIR, "chunk_embeddings.npy")


# Pydantic models for API
class ChatResponse(BaseModel):
    response: str
    timestamp: datetime


class ChatRequest(BaseModel):
    query: str
    history: list[dict] | None = None  # optional conversation history

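# Example /chat request body (illustrative values; the history entries use the
# "role"/"message" keys that _format_history below expects):
#   {
#     "query": "How do I rebuild the FAISS index?",
#     "history": [
#       {"role": "user", "message": "Hi"},
#       {"role": "assistant", "message": "Hello! How can I help?"}
#     ]
#   }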

# Initialize Gemini API
# API key is configured during app startup (lifespan) to avoid import-time failures


# Lifespan function to handle startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup (no DB setup anymore)
    print("Starting RAG Chat API (stateless, no database)...")

    API_KEY = os.getenv("GEMINI_API_KEY")
    if not API_KEY:
        raise RuntimeError("Please set GEMINI_API_KEY environment variable")
    genai.configure(api_key=API_KEY)

    success, chunks_count = initialize_system()
    if success:
        print(f"✅ RAG system initialized with {chunks_count} chunks")
        print("API ready at: http://localhost:8000 (docs at /docs)")
    else:
        raise RuntimeError("System initialization failed")

    yield

    print("Shutting down RAG Chat API...")


# Initialize FastAPI app with lifespan
app = FastAPI(
    title="RAG Chat API",
    description="RAG Chat System with Database Integration",
    lifespan=lifespan,
)

"""API only module.

The web chat UI has been moved to a separate app (see web_app.py). This file now
exposes only the JSON API endpoints so it can be containerised independently or
scaled separately from the frontend.
"""


# Enable CORS for local dev
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost",
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "http://localhost:8001",  # web UI container
        "http://127.0.0.1:8001",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables to store precomputed data
chunks_data = None
chunk_embeddings = None
faiss_index = None
base_chunk = None
system_prompt = None
model_embedding = None


# Removed database session dependency (stateless mode)


def load_chunks(json_file):
    """Load chunks from JSON file"""
    try:
        with open(json_file, "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        raise FileNotFoundError(f"File {json_file} not found!")
    except json.JSONDecodeError:
        raise ValueError(f"Invalid JSON in {json_file}")


def compute_and_cache_embeddings(chunks):
    """Compute embeddings for all chunks and cache them"""
    global chunk_embeddings, faiss_index

    print("Computing embeddings for all chunks...")
    texts = [chunk["content"] for chunk in chunks]

    # Load or compute embeddings
    if os.path.exists(EMBEDDINGS_FILE):
        print("Loading cached embeddings...")
        chunk_embeddings = np.load(EMBEDDINGS_FILE)
        if chunk_embeddings.shape[0] != len(texts):
            print("Cached embeddings count mismatches chunks. Recomputing...")
            chunk_embeddings = model_embedding.encode(
                texts, convert_to_numpy=True
            ).astype("float32")
            np.save(EMBEDDINGS_FILE, chunk_embeddings)
    else:
        print("Computing new embeddings (this may take a moment)...")
        chunk_embeddings = model_embedding.encode(texts, convert_to_numpy=True).astype(
            "float32"
        )
        np.save(EMBEDDINGS_FILE, chunk_embeddings)
        print("Embeddings cached for future use.")

    # Normalize embeddings (for cosine similarity with IndexFlatIP)
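    # (after L2 normalization, the inner product of two vectors equals their
    # cosine similarity, so IndexFlatIP ranks hits by cosine score)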
    faiss.normalize_L2(chunk_embeddings)

    # Create or load FAISS index
    embedding_dim = chunk_embeddings.shape[1]
    if os.path.exists(FAISS_INDEX_FILE):
        print("Loading cached FAISS index...")
        faiss_index = faiss.read_index(FAISS_INDEX_FILE)
        # Validate index matches embeddings
        if getattr(faiss_index, "ntotal", 0) != chunk_embeddings.shape[0]:
            print("FAISS index size mismatches embeddings. Rebuilding index...")
            faiss_index = faiss.IndexFlatIP(embedding_dim)
            faiss_index.add(chunk_embeddings)
            faiss.write_index(faiss_index, FAISS_INDEX_FILE)
    else:
        print("Creating new FAISS index...")
        faiss_index = faiss.IndexFlatIP(embedding_dim)
        faiss_index.add(chunk_embeddings)
        faiss.write_index(faiss_index, FAISS_INDEX_FILE)
        print("FAISS index cached for future use.")


def retrieve_relevant_chunks(query, top_k=3):
    """Retrieve most relevant chunks for the query using precomputed embeddings"""
    global chunks_data, faiss_index

    if faiss_index is None or chunks_data is None:
        raise RuntimeError("RAG index not initialized")

    # Encode query
    query_embedding = model_embedding.encode([query], convert_to_numpy=True).astype(
        "float32"
    )
    faiss.normalize_L2(query_embedding)

    top_k = min(top_k, len(chunks_data))
    # Search in precomputed index
    _, indices = faiss_index.search(query_embedding, top_k)
    return [chunks_data[i] for i in indices[0]]


def _format_history(history: list[dict] | None, max_turns: int = 6) -> str:
    """Format recent conversation history for inclusion in the prompt."""
    if not history:
        return ""
    recent = history[-max_turns:]
    lines = []
    for turn in recent:
        role = turn.get("role", "user")
        msg = (turn.get("message") or "").strip()
        if not msg:
            continue
        prefix = "User" if role == "user" else "Assistant"
        lines.append(f"{prefix}: {msg}")
    return "\n".join(lines)

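# For illustration, _format_history([{"role": "user", "message": "Hi"},
# {"role": "assistant", "message": "Hello!"}]) returns:
#   "User: Hi\nAssistant: Hello!"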

def construct_prompt(base_chunk, system_prompt, query, history_text: str = ""):
    """Construct the full prompt with relevant context"""
    relevant_chunks = retrieve_relevant_chunks(query)
    context = "\n\n".join(chunk["content"] for chunk in relevant_chunks)
    full_prompt = (
        f"System prompt:\n{system_prompt['content']}\n\n"
        f"Context:\n{context}\n\n"
        f"{base_chunk['content']}\n\n"
    )
    if history_text:
        full_prompt += f"Recent conversation:\n{history_text}\n\n"
    full_prompt += f"Query:\n{query}"
    return full_prompt, context


def get_answer(prompt):
    """Get answer from Gemini API"""
    try:
        model = genai.GenerativeModel("gemini-2.5-flash")
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        print(f"Error getting response from Gemini: {e}")
        return None


def run_generate_rag_data():
    """Run the data generation script if available."""
    script_path = os.path.join(SCRIPT_DIR, "generate_rag_data.py")
    if not os.path.isfile(script_path):
        print("generate_rag_data.py not found; skipping automatic generation.")
        return
    print("Running generate_rag_data.py to build RAG data...")
    try:
        subprocess.run([sys.executable, script_path], cwd=SCRIPT_DIR, check=True)
        print("generate_rag_data.py completed.")
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"generate_rag_data.py failed (exit {e.returncode})") from e


def initialize_system():
    """Initialize the RAG system with precomputed embeddings (stateless)."""
    global chunks_data, base_chunk, system_prompt, model_embedding
    try:
        need_generation = (
            not os.path.exists(EMBEDDINGS_FILE)
            or not os.path.exists(OUTPUT_CHUNKS_FILE)
            or not os.path.exists(RAG_CONFIG_FILE)
        )
        if need_generation:
            print("RAG data or embeddings missing. Triggering data generation...")
            run_generate_rag_data()

        print("Loading embedding model...")
        model_embedding = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")

        print("Loading chunks and configuration...")
        chunks_data = load_chunks(OUTPUT_CHUNKS_FILE)
        config = load_chunks(RAG_CONFIG_FILE)[0]
        base_chunk = config["base_chunk"]
        system_prompt = config["system_prompt"]
        print(f"Loaded {len(chunks_data)} chunks from knowledge base")

        compute_and_cache_embeddings(chunks_data)
        print("System initialized successfully (stateless mode)")
        return True, len(chunks_data)
    except Exception as e:
        print(f"Failed to initialize system: {e}")
        return False, 0


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(payload: ChatRequest):
    """Chat endpoint that processes queries (no persistence)."""
    global base_chunk, system_prompt
    query = (payload.query or "").strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    history_text = _format_history(payload.history)
    full_prompt, _context = construct_prompt(
        base_chunk, system_prompt, query, history_text
    )

    answer = await run_in_threadpool(get_answer, full_prompt)
    if not answer:
        answer = "Sorry, I failed to get a response from Gemini. Please try again."
    return ChatResponse(response=answer, timestamp=datetime.now(timezone.utc))

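# Example call (assuming the API is served locally on port 8000):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"query": "What does this service do?"}'
# The JSON response carries "response" and "timestamp" fields.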

# Simple health probe
@app.get("/health")
def health():
    return {"status": "ok"}

@app.get("/")
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse
    return RedirectResponse(url="/docs")


if __name__ == "__main__":
    # Fail fast if the API key is missing
    if not os.getenv("GEMINI_API_KEY"):
        print("Warning: GEMINI_API_KEY environment variable not set!")
        print("Set it first, e.g. 'set GEMINI_API_KEY=...' (Windows) or 'export GEMINI_API_KEY=...' (Unix)")
        sys.exit(1)

    print("Starting RAG Chat API server...")
    # Pass the app object directly; auto-reload requires an import string
    # ("module:app"), which depends on how this file is named.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")