Sbnos committed
Commit bcf3e78 · 1 Parent(s): 35e79b2

adding other specialties

Files changed (1):
  app.py  +270 -117
app.py CHANGED
@@ -4,142 +4,295 @@ from together import Together
 from langchain_community.vectorstores import Chroma
 from langchain_huggingface import HuggingFaceEmbeddings

-# --- Configuration ---
-# TogetherAI API key (env var name pilotikval)
 TOGETHER_API_KEY = os.environ.get("pilotikval")
 if not TOGETHER_API_KEY:
-    st.error("Missing pilotikval environment variable.")
     st.stop()

 # Initialize TogetherAI client
-client = Together(api_key=TOGETHER_API_KEY)
-
-# Embeddings setup
-EMBED_MODEL_NAME = "BAAI/bge-base-en"
-embeddings = HuggingFaceEmbeddings(
-    model_name=EMBED_MODEL_NAME,
-    encode_kwargs={"normalize_embeddings": True},
-)
-
-# Sidebar: select collection
-st.sidebar.title("DocChatter RAG")
-collection = st.sidebar.selectbox(
-    "Choose a document collection:",
-    ['General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine']
-)
-
-dirs = {
-    'General Medicine': './oxfordmedbookdir/',
-    'RespiratoryFishman': './respfishmandbcud/',
-    'RespiratoryMurray': './respmurray/',
-    'MedMRCP2': './medmrcp2store/',
-    'OldMedicine': './mrcpchromadb/'
-}
-cols = {
-    'General Medicine': 'oxfordmed',
-    'RespiratoryFishman': 'fishmannotescud',
-    'RespiratoryMurray': 'respmurraynotes',
-    'MedMRCP2': 'medmrcp2notes',
-    'OldMedicine': 'mrcppassmednotes'
-}
-
-persist_directory = dirs[collection]
-collection_name = cols[collection]
-
-# Load Chroma vector store
-vectorstore = Chroma(
-    collection_name=collection_name,
-    persist_directory=persist_directory,
-    embedding_function=embeddings
-)
-retriever = vectorstore.as_retriever(search_kwargs={"k": 20})  # k=20
-
-# System prompt template
-
-def build_system(context: str) -> dict:
-    """
-    Build a comprehensive system prompt:
-    - Act as an expert medical assistant and attentive listener.
-    - Leverage retrieved context to craft detailed, accurate, and empathetic responses.
-    - Ask clarifying follow-up questions if the user's query is ambiguous.
-    - Structure answers clearly with headings, bullet points, and step-by-step explanations.
-    - Cite relevant context sections when appropriate.
-    - Maintain conversational memory for follow-up continuity.
-    """
-    prompt = f"""
-    You are a world-class medical assistant and conversational partner.
-    Listen carefully to the user's questions, reference the context below, and provide a thorough, evidence-based response.
-    If any part of the question is unclear, ask a clarifying question before proceeding.
-    Organize your answer with clear headings or bullet points, and refer back to specific context snippets as needed.
-    Always be empathetic, concise, and precise in your medical explanations.
-    Retain memory of previous user messages to support follow-up interactions.
-    === Retrieved Context Start ===
 {context}
-    === Retrieved Context End ===
 """
     return {"role": "system", "content": prompt}

-st.title("🩺 DocChatter RAG (Streaming & Memory)")
-
-# Initialize chat history
-if 'chat_history' not in st.session_state:
-    st.session_state.chat_history = []  # list of dicts {role, content}
-
-# Get user input at top level
-user_prompt = st.chat_input("Ask anything about your docs…")
-
-# Tabs for UI
-chat_tab, clear_tab = st.tabs(["Chat", "Clear History"])
-
-with chat_tab:
-    # Display existing chat
-    for msg in st.session_state.chat_history:
-        st.chat_message(msg['role']).write(msg['content'])
-
-    # Handle new user input
-    if user_prompt:
-        # Echo user
-        st.chat_message("user").write(user_prompt)
-        st.session_state.chat_history.append({"role": "user", "content": user_prompt})
-
-        # Retrieve top-k documents
         try:
-            docs = retriever.invoke({"query": user_prompt})
-        except Exception:
-            docs = retriever.get_relevant_documents(user_prompt)
-        context = "\n---\n".join([d.page_content for d in docs])
-
-        # Build TogetherAI message sequence
-        messages = [build_system(context)]
-        for m in st.session_state.chat_history:
-            messages.append(m)
-
-        # Stream assistant response
-        response_container = st.chat_message("assistant")
-        stream_placeholder = response_container.empty()
-        answer = ""
-
-        for token in client.chat.completions.create(
-            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
-            messages=messages,
-            max_tokens=22048,
-            temperature=0.1,
-            stream=True
-        ):
-            try:
-                choice = token.choices[0]
-                delta = getattr(choice.delta, 'content', '')
-                if delta:
-                    answer += delta
-                    stream_placeholder.write(answer)
-            except (IndexError, AttributeError):
-                continue
-
-        # Save assistant response
-        st.session_state.chat_history.append({"role": "assistant", "content": answer})
-
-with clear_tab:
-    if st.button("🗑️ Clear chat history"):
-        st.session_state.chat_history = []
-        st.experimental_rerun()

 from langchain_community.vectorstores import Chroma
 from langchain_huggingface import HuggingFaceEmbeddings

+# ============================================================================
+# CONFIGURATION
+# ============================================================================
+
+# Vector store configurations
+VECTOR_STORES = {
+    "Paediatrics": {
+        "collection_name": "paedia",
+        "persist_directory": "nelsonpaedia"
+    },
+    "Respiratory": {
+        "collection_name": "respmurraynotes",
+        "persist_directory": "respmurray"
+    },
+    "Dermatology": {
+        "collection_name": "derma",
+        "persist_directory": "rookderma"
+    },
+    "Endocrine": {
+        "collection_name": "endocrine",
+        "persist_directory": "williamsendocrine"
+    },
+    "Gastroenterology": {
+        "collection_name": "gastro",
+        "persist_directory": "yamadagastro"
+    },
+    "Surgery": {
+        "collection_name": "gensurgery",
+        "persist_directory": "baileysurgery"
+    },
+    "Neurology": {
+        "collection_name": "neuro",
+        "persist_directory": "bradleyneuro"
+    },
+    "Cardiology": {
+        "collection_name": "cardiobraun",
+        "persist_directory": "braunwaldcardiofin"
+    },
+    "Nephrology": {
+        "collection_name": "nephro",
+        "persist_directory": "brennernephro"
+    },
+    "Orthopedics": {
+        "collection_name": "oportho",
+        "persist_directory": "campbellorthop"
+    },
+    "Rheumatology": {
+        "collection_name": "rheumatology",
+        "persist_directory": "firesteinrheumatology"
+    }
+}
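
Each entry above maps one specialty to an on-disk Chroma collection, so adding a specialty is a single dictionary entry. A hypothetical startup guard (not part of this commit; `AVAILABLE_STORES` is an illustrative name) could keep the sidebar from offering collections whose directory is missing:

    import os

    # Illustrative only: expose a specialty in the UI only if its
    # persist_directory actually exists on disk.
    AVAILABLE_STORES = {
        name: cfg
        for name, cfg in VECTOR_STORES.items()
        if os.path.isdir(cfg["persist_directory"])
    }
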
+
+# Model configurations
+EMBED_MODEL = "BAAI/bge-base-en"
+LLM_MODEL = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+RETRIEVAL_K = 20
+
+# ============================================================================
+# PAGE CONFIG
+# ============================================================================
+
+st.set_page_config(
+    page_title="DocChatter Medical RAG",
+    page_icon="🩺",
+    layout="wide"
+)
+
+# ============================================================================
+# INITIALIZATION
+# ============================================================================
+
+# Check API key
 TOGETHER_API_KEY = os.environ.get("pilotikval")
 if not TOGETHER_API_KEY:
+    st.error("❌ Missing 'pilotikval' environment variable. Please set your TogetherAI API key.")
     st.stop()

 # Initialize TogetherAI client
+@st.cache_resource
+def get_together_client():
+    return Together(api_key=TOGETHER_API_KEY)
+
+client = get_together_client()
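`st.cache_resource` makes the Together client a per-process singleton: the decorated function body runs once, and every rerun and session reuses the same object instead of rebuilding it on each script run. A minimal sketch of that behaviour (illustrative names, not from the commit):

    import streamlit as st

    @st.cache_resource
    def get_client():
        print("constructed")  # printed once per server process
        return object()

    assert get_client() is get_client()  # later calls return the cached instance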
 
+# Initialize embeddings
+@st.cache_resource
+def get_embeddings():
+    return HuggingFaceEmbeddings(
+        model_name=EMBED_MODEL,
+        encode_kwargs={"normalize_embeddings": True}
+    )
+
+embeddings = get_embeddings()
+
+# ============================================================================
+# SESSION STATE
+# ============================================================================
+
+if 'chat_history' not in st.session_state:
+    st.session_state.chat_history = []
+
+if 'selected_collection' not in st.session_state:
+    st.session_state.selected_collection = list(VECTOR_STORES.keys())[0]
+
+# ============================================================================
+# HELPER FUNCTIONS
+# ============================================================================
+
+@st.cache_resource
+def load_vectorstore(_embeddings, collection_name, persist_directory):
+    """Load the Chroma collection and return a cached retriever."""
+    vectorstore = Chroma(
+        collection_name=collection_name,
+        persist_directory=persist_directory,
+        embedding_function=_embeddings
+    )
+    return vectorstore.as_retriever(search_kwargs={"k": RETRIEVAL_K})
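
The leading underscore in `_embeddings` is the Streamlit convention for unhashable arguments: `st.cache_resource` skips hashing any parameter whose name starts with `_`, so the cache key here is effectively `(collection_name, persist_directory)` while the embeddings object is passed through untouched.
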
+
+def build_system_prompt(context: str) -> dict:
+    """Build the system prompt around the retrieved context."""
+    prompt = f"""You are an expert medical assistant with access to authoritative medical literature.
+
+Your role:
+- Provide accurate, evidence-based medical information
+- Answer questions clearly and comprehensively
+- Ask clarifying questions if needed
+- Use the context below to support your answers
+- Be empathetic and professional
+- Remember previous messages in the conversation
+
+Retrieved Context:
 {context}
+
+Instructions:
+- Base your answers on the provided context
+- If the context doesn't contain relevant information, acknowledge this
+- Structure complex answers with clear organization
+- Cite specific information when referencing the context
 """
     return {"role": "system", "content": prompt}

+def stream_llm_response(messages):
+    """Yield the accumulating response text as chunks arrive from TogetherAI."""
+    response = ""
+    for chunk in client.chat.completions.create(
+        model=LLM_MODEL,
+        messages=messages,
+        max_tokens=4096,
+        temperature=0.1,
+        stream=True
+    ):
+        try:
+            if chunk.choices[0].delta.content:
+                response += chunk.choices[0].delta.content
+                yield response
+        except (IndexError, AttributeError):
+            continue
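
Each `yield` hands back the full text accumulated so far rather than only the newest delta, so the caller just repaints one placeholder with the latest value. Consuming the generator outside Streamlit (assuming the function and a `messages` list as above) would look like:

    final = ""
    for partial in stream_llm_response(messages):
        final = partial  # every yield supersedes the previous one
    print(final)
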

+# ============================================================================
+# SIDEBAR
+# ============================================================================

+with st.sidebar:
+    st.title("🩺 DocChatter Medical RAG")
+    st.markdown("---")
+
+    # Collection selector
+    st.subheader("📚 Select Medical Specialty")
+    selected = st.selectbox(
+        "Choose a collection:",
+        options=list(VECTOR_STORES.keys()),
+        index=list(VECTOR_STORES.keys()).index(st.session_state.selected_collection),
+        key="collection_selector"
+    )
+
+    if selected != st.session_state.selected_collection:
+        st.session_state.selected_collection = selected
+        st.rerun()
+
+    st.markdown("---")
+
+    # Stats
+    st.subheader("📊 Session Info")
+    st.metric("Messages", len(st.session_state.chat_history))
+    st.metric("Current Collection", selected)
+
+    st.markdown("---")
+
+    # Clear button
+    if st.button("🗑️ Clear Chat History", use_container_width=True):
+        st.session_state.chat_history = []
+        st.rerun()
+
+    st.markdown("---")
+    st.caption("Powered by TogetherAI & LangChain")

+# ============================================================================
+# MAIN CHAT INTERFACE
+# ============================================================================

+st.title("💬 Medical Document Chat")
+st.caption(f"Currently using: **{st.session_state.selected_collection}** collection")

+# Load retriever for selected collection
+config = VECTOR_STORES[st.session_state.selected_collection]
+retriever = load_vectorstore(
+    embeddings,
+    config["collection_name"],
+    config["persist_directory"]
+)

+# Display chat history
+for i, message in enumerate(st.session_state.chat_history):
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+        # Add copy button for assistant messages
+        if message["role"] == "assistant":
+            st.button(
+                "📋 Copy",
+                key=f"copy_{i}",
+                on_click=lambda msg=message["content"]: st.toast("Copied to clipboard! (Use Ctrl+C to copy manually)"),
+                help="Click to copy this response"
+            )
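
One caveat: the Copy button's `on_click` only raises a toast; Streamlit has no clipboard-write API, so the user still selects and copies the text manually. If a built-in copy affordance is wanted, one alternative (not used in this commit) is `st.code`, which renders its own copy icon on hover:

    st.code(message["content"], language=None)
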

+# Chat input
+user_input = st.chat_input("Ask me anything about medical topics...")
+
+if user_input:
+    # Add user message
+    st.session_state.chat_history.append({
+        "role": "user",
+        "content": user_input
+    })
+
+    # Display user message
+    with st.chat_message("user"):
+        st.markdown(user_input)
+
+    # Retrieve relevant documents
+    with st.spinner("🔍 Searching medical literature..."):
         try:
+            docs = retriever.invoke(user_input)
+        except Exception:
+            docs = retriever.get_relevant_documents(user_input)
+
+        context = "\n\n---\n\n".join([doc.page_content for doc in docs])
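
`retriever.invoke(...)` is the current Runnable-style entry point for LangChain retrievers; `get_relevant_documents` is the legacy method kept as a fallback for older installations. Either way the result is a list of `Document` objects whose `page_content` strings are joined into one context block; a minimal illustration (sample texts made up, import path per current langchain_core):

    from langchain_core.documents import Document

    docs = [Document(page_content="Asthma: stepwise therapy..."),
            Document(page_content="COPD: GOLD staging...")]
    context = "\n\n---\n\n".join(doc.page_content for doc in docs)
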
+
+    # Build messages for LLM
+    messages = [build_system_prompt(context)]
+
+    # Add chat history
+    for msg in st.session_state.chat_history:
+        messages.append({
+            "role": msg["role"],
+            "content": msg["content"]
+        })
+
+    # Stream assistant response
+    with st.chat_message("assistant"):
+        response_placeholder = st.empty()
+        full_response = ""
+
+        for response_chunk in stream_llm_response(messages):
+            full_response = response_chunk
+            response_placeholder.markdown(full_response + "▌")
+
+        response_placeholder.markdown(full_response)
+
+        # Add copy button
+        copy_button_key = f"copy_{len(st.session_state.chat_history)}"
+        st.button(
+            "📋 Copy",
+            key=copy_button_key,
+            on_click=lambda: st.toast("Response ready to copy! (Use Ctrl+C)"),
+            help="Click to copy this response"
+        )
+
+    # Save assistant response
+    st.session_state.chat_history.append({
+        "role": "assistant",
+        "content": full_response
+    })
+
+    st.rerun()
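
The closing `st.rerun()` re-executes the script, so the reply just appended to `st.session_state.chat_history` is redrawn by the history loop above, with its own keyed Copy button, rather than surviving only in the transient streaming placeholder.
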
+
+# ============================================================================
+# FOOTER
+# ============================================================================
+
+st.markdown("---")
+st.caption("⚠️ This is an AI assistant. Always consult qualified healthcare professionals for medical advice.")