mryt66 committed
Commit b4f3a10 · Parent: a840639

Initial commit

Files changed (1)
  1. api.py +24 -95
api.py CHANGED
@@ -7,10 +7,8 @@ import numpy as np
 from datetime import datetime
 from contextlib import asynccontextmanager
 
-from fastapi import FastAPI, Depends, HTTPException
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from sqlalchemy import Column, Integer, Text, DateTime, create_engine
-from sqlalchemy.orm import declarative_base, sessionmaker, Session
 from pydantic import BaseModel
 import uvicorn
 from starlette.concurrency import run_in_threadpool
@@ -18,7 +16,7 @@ import subprocess, sys
 
 SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
 
-# Always use local data directory (no env var logic)
+# Always use local data directory (no env var logic and no DB)
 DATA_DIR = os.path.join(SCRIPT_DIR, "data")
 os.makedirs(DATA_DIR, exist_ok=True)
 
@@ -26,25 +24,6 @@ OUTPUT_CHUNKS_FILE = os.path.join(SCRIPT_DIR, "output_chunks.jsonl")
 RAG_CONFIG_FILE = os.path.join(SCRIPT_DIR, "rag_prompt_config.jsonl")
 FAISS_INDEX_FILE = os.path.join(DATA_DIR, "faiss_index.index")
 EMBEDDINGS_FILE = os.path.join(DATA_DIR, "chunk_embeddings.npy")
-DATABASE_URL = f"sqlite:///{os.path.join(DATA_DIR, 'conversations.db')}"
-
-Base = declarative_base()
-engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
-SessionLocal = sessionmaker(bind=engine)
-
-
-# Database model
-class Conversation(Base):
-    __tablename__ = "conversations"
-
-    id = Column(Integer, primary_key=True, index=True)
-    query = Column(Text)
-    response = Column(Text)
-    context = Column(Text)
-    base_context = Column(Text)
-    system_prompt = Column(Text)
-    full_prompt = Column(Text)
-    timestamp = Column(DateTime, default=datetime.utcnow)
 
 
 # Pydantic models for API
@@ -65,34 +44,23 @@ class ChatRequest(BaseModel):
 # Lifespan function to handle startup and shutdown
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # Startup
-    print("Starting RAG Chat API...")
-    print(f"SQLite DB path: {os.path.join(DATA_DIR, 'conversations.db')}")
-    # Ensure tables now that directory is confirmed writable
-    Base.metadata.create_all(bind=engine)
+    # Startup (no DB setup anymore)
+    print("Starting RAG Chat API (stateless, no database)...")
 
-    # Configure Gemini here (fail early but at startup)
     API_KEY = os.getenv("GEMINI_API_KEY")
     if not API_KEY:
         raise RuntimeError("Please set GEMINI_API_KEY environment variable")
     genai.configure(api_key=API_KEY)
 
-    try:
-        success, chunks_count = initialize_system()
-        if success:
-            print(f" RAG system initialized successfully with {chunks_count} chunks")
-            print("API ready at: http://localhost:8000")
-            print("API documentation at: http://localhost:8000/docs")
-        else:
-            print("❌ Failed to initialize RAG system")
-            raise RuntimeError("System initialization failed")
-    except Exception as e:
-        print(f"❌ Initialization error: {str(e)}")
-        raise RuntimeError(f"System initialization failed: {str(e)}")
+    success, chunks_count = initialize_system()
+    if success:
+        print(f"✅ RAG system initialized with {chunks_count} chunks")
+        print("API ready at: http://localhost:8000 (docs at /docs)")
+    else:
+        raise RuntimeError("System initialization failed")
 
-    yield  # This is where the app runs
+    yield
 
-    # Shutdown (if needed)
     print("Shutting down RAG Chat API...")
 
 
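Note: startup and shutdown now run through FastAPI's lifespan context manager rather than the deprecated @app.on_event hooks. A minimal sketch of the pattern (standard FastAPI API; the prints are illustrative only):

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs once before the first request is served.
    print("loading models, validating env vars...")
    yield  # the application handles requests while suspended here
    # Shutdown: runs once when the server exits.
    print("releasing resources...")

app = FastAPI(lifespan=lifespan)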
@@ -135,13 +103,7 @@ system_prompt = None
 model_embedding = None
 
 
-# Dependency to get database session
-def get_db():
-    db = SessionLocal()
-    try:
-        yield db
-    finally:
-        db.close()
+# Removed database session dependency (stateless mode)
 
 
 def load_chunks(json_file):
  def load_chunks(json_file):
@@ -278,11 +240,9 @@ def run_generate_rag_data():
278
 
279
 
280
  def initialize_system():
281
- """Initialize the RAG system with precomputed embeddings"""
282
  global chunks_data, base_chunk, system_prompt, model_embedding
283
-
284
  try:
285
- # If embeddings or required JSON files are missing, (re)generate data first.
286
  need_generation = (
287
  not os.path.exists(EMBEDDINGS_FILE)
288
  or not os.path.exists(OUTPUT_CHUNKS_FILE)
@@ -292,72 +252,41 @@ def initialize_system():
         print("RAG data or embeddings missing. Triggering data generation...")
         run_generate_rag_data()
 
-        # Initialize embedding model
         print("Loading embedding model...")
         model_embedding = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
 
-        # Load configurations
         print("Loading chunks and configuration...")
         chunks_data = load_chunks(OUTPUT_CHUNKS_FILE)
         config = load_chunks(RAG_CONFIG_FILE)[0]
         base_chunk = config["base_chunk"]
         system_prompt = config["system_prompt"]
-
         print(f"Loaded {len(chunks_data)} chunks from knowledge base")
 
-        # Precompute embeddings once (will compute if file absent)
         compute_and_cache_embeddings(chunks_data)
-
-        print("System initialized successfully!")
+        print("System initialized successfully (stateless mode)")
         return True, len(chunks_data)
-
     except Exception as e:
         print(f"Failed to initialize system: {e}")
         return False, 0
 
 
 @app.post("/chat", response_model=ChatResponse)
-async def chat_endpoint(payload: ChatRequest, db: Session = Depends(get_db)):
-    """Chat endpoint that processes queries and saves conversations to database
-
-    Accepts a JSON body: {"query": "..."}
-    """
+async def chat_endpoint(payload: ChatRequest):
+    """Chat endpoint that processes queries (no persistence)."""
     global base_chunk, system_prompt
-
     query = (payload.query or "").strip()
     if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
 
-    try:
-        # Construct prompt and get answer
-        history_text = _format_history(payload.history)
-        full_prompt, context = construct_prompt(
-            base_chunk, system_prompt, query, history_text
-        )
-
-        # Avoid blocking the event loop with a sync network call
-        answer = await run_in_threadpool(get_answer, full_prompt)
-        if not answer:
-            answer = "Sorry, I failed to get a response from Gemini. Please try again."
-
-        # Save conversation to database
-        conversation = Conversation(
-            query=query,
-            response=answer,
-            context=context,
-            base_context=base_chunk["content"],
-            system_prompt=system_prompt["content"],
-            full_prompt=full_prompt,
-        )
-
-        db.add(conversation)
-        db.commit()
-
-        return ChatResponse(response=answer, timestamp=conversation.timestamp)
+    history_text = _format_history(payload.history)
+    full_prompt, _context = construct_prompt(
+        base_chunk, system_prompt, query, history_text
+    )
 
-    except Exception as e:
-        db.rollback()
-        raise HTTPException(status_code=500, detail=f"Chat processing error: {str(e)}")
+    answer = await run_in_threadpool(get_answer, full_prompt)
+    if not answer:
+        answer = "Sorry, I failed to get a response from Gemini. Please try again."
+    return ChatResponse(response=answer, timestamp=datetime.utcnow())
 
 
 # Simple health probe
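Note: initialize_system() relies on compute_and_cache_embeddings, whose body is unchanged and therefore not shown in this diff. A hedged sketch of what such a cache typically does, assuming chunks carry a "content" field (as base_chunk["content"] above suggests) and using the real SentenceTransformer.encode and numpy save/load APIs:

import os
import numpy as np
from sentence_transformers import SentenceTransformer

def compute_or_load_embeddings(chunks, cache_path, model):
    """Encode chunk texts once; later startups reuse the .npy cache."""
    if os.path.exists(cache_path):
        return np.load(cache_path)          # cache hit: skip encoding
    texts = [c["content"] for c in chunks]  # assumption: chunks have "content"
    embeddings = model.encode(texts, convert_to_numpy=True)
    np.save(cache_path, embeddings)
    return embeddings

# e.g. compute_or_load_embeddings(chunks_data, EMBEDDINGS_FILE,
#                                 SentenceTransformer("Qwen/Qwen3-Embedding-0.6B"))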
 
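Note: with persistence removed, the server keeps no conversation state, so any multi-turn context must travel with each request via ChatRequest.history (formatted server-side by _format_history). The exact history schema is not shown in this diff; the list-of-messages shape below is an assumption:

import requests

resp = requests.post(
    "http://localhost:8000/chat",
    json={
        "query": "Summarize the key points from the knowledge base.",
        "history": [  # hypothetical shape; match whatever _format_history expects
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello! How can I help?"},
        ],
    },
    timeout=60,
)
resp.raise_for_status()
data = resp.json()  # ChatResponse: {"response": ..., "timestamp": ...}
print(data["response"])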