import os
import contextlib
from collections import defaultdict
from typing import Dict, List

import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
import gradio as gr
from ahocorapy.keywordtree import KeywordTree
from sentence_transformers import SentenceTransformer
from FlagEmbedding import FlagModel
from transformers import AutoTokenizer, AutoModel

CSV_PATH = os.environ.get("CORPUS_CSV", "H_and_M_FINAL.csv")    # pre-indexed corpus
TEXT_COL = os.environ.get("TEXT_COLUMN", "text")                # column with passage text
IMAGE_COL = os.environ.get("IMAGE_URL_COLUMN", "image_url")     # optional image column
TOP_K = int(os.environ.get("TOP_K", 5))
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", 512))             # truncate long docs
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 8))

MODEL_REPO_MAP = {
    "intfloat/e5-small-v2": "intfloat/e5-small-v2",
    "BAAI/bge-small-en-v1.5": "BAAI/bge-small-en-v1.5",
}
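
# To expose another embedding model in the UI, add a display-name → repo entry
# above and make sure EmbeddingBackend (below) knows how to encode queries for
# it; anything unrecognised falls through to the generic SentenceTransformer path.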

@contextlib.contextmanager  # without this decorator, `with inference_mode():` raises at runtime
def inference_mode():
    """Context manager that disables autograd tracking while encoding."""
    with torch.inference_mode():
        yield

def truncate(text: str, max_tokens: int = MAX_TOKENS) -> str:
    """Very rough truncation by characters (≈ 4 characters per token)."""
    approx_chars = max_tokens * 4  # deliberate over-estimate
    return text[:approx_chars]
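
# The 4-chars-per-token budget is only a heuristic: with MAX_TOKENS = 512,
# truncate() keeps at most 2048 characters, e.g. truncate("a" * 5000) returns
# the first 2048 characters. Real token counts vary by tokenizer; this merely
# bounds the worst case before the model's own truncation kicks in.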

class EmbeddingBackend:
    """Adapter that presents .encode_query / .encode_docs for all models."""

    def __init__(self, repo: str):
        self.repo = repo

        # ---------- BGE (FlagEmbedding) ----------
        if repo == "BAAI/bge-small-en-v1.5":
            self.model = FlagModel(
                repo,
                query_instruction_for_retrieval="Generate a representation for this sentence to retrieve related articles:",
                use_fp16=True,
            )
            self.encode_docs = lambda docs: self.model.encode(docs, batch_size=BATCH_SIZE)
            self.encode_query = lambda q: self.model.encode_queries([q])[0]
            return

        # ---------- E5 (mean pooling over non-padding tokens) ----------
        if repo == "intfloat/e5-small-v2":
            self.tokenizer = AutoTokenizer.from_pretrained(repo)
            self.model = AutoModel.from_pretrained(repo)

            def _embed(texts: List[str]):
                batch_dict = self.tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors="pt")
                with inference_mode():
                    outputs = self.model(**batch_dict)
                # Zero out padding positions, then mean-pool over the real tokens.
                hidden = outputs.last_hidden_state.masked_fill(~batch_dict["attention_mask"].bool().unsqueeze(-1), 0.0)
                emb = hidden.sum(1) / batch_dict["attention_mask"].sum(1, keepdim=True)
                return F.normalize(emb, p=2, dim=1).cpu().numpy()

            # E5 expects "passage: " / "query: " prefixes at encode time.
            self.encode_docs = lambda docs: _embed([f"passage: {d}" for d in docs])
            self.encode_query = lambda q: _embed([f"query: {q}"])[0]
            return

        # ---------- Fallback: any SentenceTransformer model ----------
        model_kwargs = {}
        if "Qwen3" in repo and not os.getenv("QWEN_USE_FLASH"):
            model_kwargs["attn_implementation"] = "eager"
        self.model = SentenceTransformer(repo, trust_remote_code=True, model_kwargs=model_kwargs)
        if "Qwen3" in repo:
            self.encode_query = lambda q: self.model.encode(q, prompt_name="query")
        else:
            self.encode_query = lambda q: self.model.encode(q)
        self.encode_docs = lambda docs: self.model.encode(docs, batch_size=BATCH_SIZE, normalize_embeddings=False)

    # ---------- Public wrappers (always return numpy arrays) ----------
    def encode_corpus(self, passages: List[str]) -> np.ndarray:
        return np.asarray(self.encode_docs(passages))

    def encode_question(self, question: str) -> np.ndarray:
        return np.asarray(self.encode_query(question))
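
# Minimal usage sketch (hypothetical strings, not executed at import time):
#   backend = EmbeddingBackend("intfloat/e5-small-v2")
#   doc_vecs = backend.encode_corpus(["red summer dress", "slim-fit jeans"])  # shape (2, dim)
#   q_vec = backend.encode_question("lightweight dress for summer")          # shape (dim,)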

# --------------------------------------------------
# Hybrid (exact → semantic) index
# --------------------------------------------------
class HybridIndex:
    def __init__(self, df: pd.DataFrame, text_col: str, backend: EmbeddingBackend):
        self.df = df
        self.text_col = text_col
        self.backend = backend
        self.text_to_rows = defaultdict(list)  # passage → [row ids]
        self.ac_tree = self._build_ac()
        self.embeddings = self._build_emb()

    # ---------- exact match ----------
    def _build_ac(self):
        # The automaton is built over the *passages*, so an exact hit means a
        # whole passage string occurs verbatim (case-insensitively) in the query.
        tree = KeywordTree(case_insensitive=True)
        for i, passage in self.df[self.text_col].astype(str).items():
            tree.add(passage)
            self.text_to_rows[passage].append(i)
        tree.finalize()
        return tree

    def exact_hits(self, query: str) -> List[int]:
        rows = set()
        for keyword, _ in self.ac_tree.search_all(query):
            rows.update(self.text_to_rows[keyword])
        return list(rows)

    # ---------- semantic ----------
    def _build_emb(self):
        docs = self.df[self.text_col].astype(str).tolist()
        emb = self.backend.encode_corpus(docs)
        emb_norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)
        return emb_norm.astype(np.float32)

    def semantic_hits(self, query: str, k: int = TOP_K) -> List[int]:
        q = self.backend.encode_question(query)
        q = q / np.linalg.norm(q)
        scores = self.embeddings @ q  # cosine similarities (both sides unit-norm)
        return np.argsort(-scores)[:k].tolist()
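
# Minimal sketch of the exact → semantic fallback (hypothetical query string):
#   index = HybridIndex(CORPUS_DF, TEXT_COL, EmbeddingBackend("BAAI/bge-small-en-v1.5"))
#   rows = index.exact_hits("denim jacket") or index.semantic_hits("denim jacket", k=TOP_K)
#   hits = CORPUS_DF.iloc[rows]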

# --------------------------------------------------
# Build indices at start-up
# --------------------------------------------------
def load_corpus(path: str) -> pd.DataFrame:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Corpus CSV not found: {path}")
    df = pd.read_csv(path)
    if TEXT_COL not in df.columns:
        raise ValueError(f"'{TEXT_COL}' column missing in {path}")
    # Row ids are later used positionally via .iloc, so normalise the index.
    return df.reset_index(drop=True)

def build_indices(df: pd.DataFrame) -> Dict[str, HybridIndex]:
    indices: Dict[str, HybridIndex] = {}
    for repo in MODEL_REPO_MAP.values():
        print(f"→ Building index for {repo}…", flush=True)
        backend = EmbeddingBackend(repo)
        indices[repo] = HybridIndex(df, TEXT_COL, backend)
    return indices

print("Loading corpus & initialising indices… (first run may take several minutes)")
CORPUS_DF = load_corpus(CSV_PATH)
INDICES = build_indices(CORPUS_DF)
print("✅ All indices ready.")

# --------------------------------------------------
# Search handler
# --------------------------------------------------
def search(query: str, model_repo: str):
    if not query:
        raise gr.Error("Please enter a query.")
    if model_repo not in INDICES:
        raise gr.Error("Selected model is not indexed.")
    idx = INDICES[model_repo]

    # Exact substring hits first; fall back to semantic retrieval.
    rows = idx.exact_hits(query)
    if not rows:
        rows = idx.semantic_hits(query)

    subset_cols = [TEXT_COL]
    if IMAGE_COL and IMAGE_COL in CORPUS_DF.columns:
        subset_cols.append(IMAGE_COL)
    result_df = CORPUS_DF.iloc[rows][subset_cols]

    # -------- image gallery --------
    gallery = []
    if IMAGE_COL and IMAGE_COL in result_df.columns:
        for url in result_df[IMAGE_COL].dropna():
            try:
                # Cheap reachability probe; skip URLs that error out or return 4xx/5xx.
                if requests.head(url, timeout=2).ok:
                    gallery.append(url)
            except requests.RequestException:
                continue
    return result_df, gallery

# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
with gr.Blocks(title="Hybrid RAG Search") as demo:
    gr.Markdown(
        """
        # Hybrid Retrieval-Augmented Search
        The dataset is pre-indexed for **e5-small-v2** and **bge-small-en-v1.5**.
        * **Exact substring** match via Aho-Corasick first.
        * **Semantic** top-K (default 5) retrieval if no exact hit is found.
        """
    )
    with gr.Row():
        model_sel = gr.Dropdown(
            choices=list(MODEL_REPO_MAP.keys()),
            label="Embedding Model",
            value="BAAI/bge-small-en-v1.5",
        )
        query_box = gr.Textbox(label="Ask a question…", lines=2)
    search_btn = gr.Button("Search", variant="primary")
    results = gr.Dataframe(interactive=False)
    gallery = gr.Gallery(label="Images", columns=4, height="auto")
    search_btn.click(search, inputs=[query_box, model_sel], outputs=[results, gallery])

if __name__ == "__main__":
    demo.launch()