import os
import sys
import html

# ============================================================================
# CRITICAL: Set API key BEFORE importing Together
# ============================================================================

# Check API key and set TOGETHER_API_KEY environment variable.
# The Together client reads TOGETHER_API_KEY from the environment, so the key
# must be in place before the client is constructed below.
API_KEY = os.environ.get("pilotikval")
if not API_KEY:
    print("❌ Missing 'pilotikval' environment variable. Please set your TogetherAI API key.")
    sys.exit(1)

# Set TOGETHER_API_KEY for the Together client
os.environ["TOGETHER_API_KEY"] = API_KEY

# NOW import Together and other dependencies
import streamlit as st
import streamlit.components.v1
from together import Together
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from huggingface_hub import snapshot_download

# ============================================================================
# CONFIGURATION
# ============================================================================

# Your HuggingFace dataset repository containing all vector stores
DATASET_REPO = "Sbnos/vstoryies"

# Vector store configurations: UI label -> Chroma collection name + on-disk dir
VECTOR_STORES = {
    "General Medicine": {"collection_name": "oxfordmed", "persist_directory": "oxfordmedbookdir"},
    "Paediatrics": {"collection_name": "paedia", "persist_directory": "nelsonpaedia"},
    "Respiratory": {"collection_name": "respmurraynotes", "persist_directory": "respmurray"},
    "Dermatology": {"collection_name": "derma", "persist_directory": "rookderma"},
    "Endocrine": {"collection_name": "endocrine", "persist_directory": "williamsendocrine"},
    "Gastroenterology": {"collection_name": "gastro", "persist_directory": "yamadagastro"},
    "Surgery": {"collection_name": "gensurgery", "persist_directory": "baileysurgery"},
    "Neurology": {"collection_name": "neuro", "persist_directory": "bradleyneuro"},
    "Cardiology": {"collection_name": "cardiobraun", "persist_directory": "braunwaldcardiofin"},
    "Nephrology": {"collection_name": "nephro", "persist_directory": "brennernephro"},
    "Orthopedics": {"collection_name": "oportho", "persist_directory": "campbellorthop"},
    "Rheumatology": {"collection_name": "rheumatology", "persist_directory": "firesteinrheumatology"},
}

# Model configurations
EMBED_MODEL = "BAAI/bge-base-en"
LLM_MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct"
RETRIEVAL_K = 26  # number of document chunks retrieved per query

# ============================================================================
# PAGE CONFIG
# ============================================================================

st.set_page_config(
    page_title="DocChatter Medical RAG",
    page_icon="🩺",
    layout="wide"
)

# ============================================================================
# INITIALIZATION
# ============================================================================

# Initialize TogetherAI client (uses TOGETHER_API_KEY set above)
try:
    client = Together()
except Exception as e:
    st.error(f"❌ Failed to initialize Together client: {e}")
    st.stop()


@st.cache_resource
def download_all_vectorstores():
    """Download all vector stores from the HuggingFace dataset repository.

    Cached so the download runs at most once per server process.

    BUG FIX: the original guard was ``if not any(os.path.exists(...))``,
    which skipped the download whenever even ONE persist directory was
    present — a partially downloaded setup was never repaired. We now
    download whenever at least one configured store is missing.
    """
    if not all(os.path.exists(config["persist_directory"]) for config in VECTOR_STORES.values()):
        with st.spinner("📥 Downloading vector stores from HuggingFace (one-time setup)..."):
            try:
                snapshot_download(
                    repo_id=DATASET_REPO,
                    repo_type="dataset",
                    local_dir=".",
                    allow_patterns=["*"]
                )
                st.success("✅ Vector stores downloaded successfully!")
            except Exception as e:
                st.error(f"❌ Failed to download vector stores: {e}")
                st.stop()


# Download vector stores if needed
download_all_vectorstores()


# Initialize embeddings (cached — the model is loaded once per process)
@st.cache_resource
def get_embeddings():
    """Return the shared HuggingFace embedding model with normalized vectors."""
    return HuggingFaceEmbeddings(
        model_name=EMBED_MODEL,
        encode_kwargs={"normalize_embeddings": True}
    )


embeddings = get_embeddings()

# ============================================================================
# SESSION STATE
# ============================================================================

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []
if 'selected_collection' not in st.session_state:
    # Default to the first configured specialty (dicts preserve insertion order)
    st.session_state.selected_collection = next(iter(VECTOR_STORES))

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

@st.cache_resource
def load_vectorstore(_embeddings, collection_name, persist_directory):
    """Load and cache a Chroma vector store; return it as a retriever.

    The leading underscore on ``_embeddings`` tells Streamlit's cache not to
    hash that argument (the embedding model is not hashable).
    """
    vectorstore = Chroma(
        collection_name=collection_name,
        persist_directory=persist_directory,
        embedding_function=_embeddings
    )
    return vectorstore.as_retriever(search_kwargs={"k": RETRIEVAL_K})


def build_system_prompt(context: str) -> dict:
    """Build the system message (role/content dict) embedding retrieved context."""
    prompt = f"""You are an expert medical assistant with access to authoritative medical literature.

Your role:
- Provide accurate, evidence-based medical information
- Answer questions clearly and comprehensively
- Ask clarifying questions if needed
- Use the context below to support your answers
- Be empathetic and professional
- Remember previous messages in the conversation

Retrieved Context:
{context}

Instructions:
- Base your answers on the provided context
- If the context doesn't contain relevant information, acknowledge this
- Structure complex answers with clear organization
- Cite specific information when referencing the context
"""
    return {"role": "system", "content": prompt}


def stream_llm_response(messages):
    """Stream a chat completion from TogetherAI.

    Yields the *cumulative* response text after each chunk (suited to
    re-rendering a placeholder on every yield), not the per-chunk delta.
    """
    response = ""
    stream = client.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        max_tokens=24096,
        temperature=0.1,
        stream=True
    )
    for chunk in stream:
        # Defensive: some keep-alive chunks may lack choices or content.
        if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
            delta = chunk.choices[0].delta
            if hasattr(delta, 'content') and delta.content:
                response += delta.content
                yield response

# ============================================================================
# SIDEBAR
# ============================================================================
# ============================================================================
# Sidebar: specialty selector, session stats, and chat-history reset.
with st.sidebar:
    st.title("🩺 DocChatter Medical RAG")
    st.markdown("---")
    # Collection selector
    st.subheader("📚 Select Medical Specialty")
    selected = st.selectbox(
        "Choose a collection:",
        options=list(VECTOR_STORES.keys()),
        # Keep the widget in sync with the collection stored in session state.
        index=list(VECTOR_STORES.keys()).index(st.session_state.selected_collection),
        key="collection_selector"
    )
    # Persist a changed selection and rerun so the main pane reloads the
    # matching retriever immediately.
    if selected != st.session_state.selected_collection:
        st.session_state.selected_collection = selected
        st.rerun()
    st.markdown("---")
    # Stats
    st.subheader("📊 Session Info")
    st.metric("Messages", len(st.session_state.chat_history))
    st.metric("Current Collection", selected)
    st.markdown("---")
    # Clear button — wipes the conversation and reruns to redraw an empty chat.
    if st.button("🗑️ Clear Chat History", use_container_width=True):
        st.session_state.chat_history = []
        st.rerun()
    st.markdown("---")
    st.caption("Powered by TogetherAI & LangChain")

# ============================================================================
# MAIN CHAT INTERFACE
# ============================================================================

st.title("💬 Medical Document Chat")
st.caption(f"Currently using: **{st.session_state.selected_collection}** collection")

# Load retriever for selected collection (cached per collection via
# load_vectorstore's @st.cache_resource)
config = VECTOR_STORES[st.session_state.selected_collection]
retriever = load_vectorstore(
    embeddings,
    config["collection_name"],
    config["persist_directory"]
)

# Display chat history
for i, message in enumerate(st.session_state.chat_history):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        # Add one-click copy button for assistant messages
        # NOTE(review): nesting of this branch inside the chat bubble is
        # inferred from context — confirm intended placement.
        if message["role"] == "assistant":
            # Escape the raw message so it can be embedded safely in HTML.
            escaped_content = html.escape(message["content"])
            copy_html = f"""