Spaces:
Runtime error
Runtime error
| # api.py | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from typing import List, Tuple | |
| from prompt_search_engine import PromptSearchEngine | |
| from vectorizer import Vectorizer | |
| from datasets import load_dataset | |
# Request/response schemas for the /search API.

# Response payload: a list of (score, prompt) pairs.
# NOTE(review): search_prompts builds {"score": ..., "prompt": ...} dicts,
# not tuples, so this model does not describe the endpoint's actual output.
class QueryResponse(BaseModel):
    results: List[Tuple[float, str]]


# Request payload: the free-text query plus how many matches to return.
class QueryRequest(BaseModel):
    query: str
    n: int = 5  # used when the client omits ``n``
# Module-level search engine: assigned by startup_event(), None until then.
search_engine: "PromptSearchEngine | None" = None

# The FastAPI application instance that routes and hooks attach to.
app = FastAPI()
# Load prompts and initialize the search engine when the app starts.
# Without the decorator FastAPI never calls this hook, and search_engine
# would stay None for the lifetime of the process.
@app.on_event("startup")
def startup_event():
    """Build the module-level prompt search engine at application startup.

    Loads the "Prompt" column of the "train" split of the
    Gustavosta/Stable-Diffusion-Prompts dataset. It then passes those prompts,
    together with a ``Vectorizer`` configured for the "all-MiniLM-L6-v2"
    model, to ``PromptSearchEngine``. The result is stored in the
    module-level ``search_engine``, where the /search endpoint reads it.
    """
    global search_engine  # rebinds the module-level name, not a local
    # Load the prompts
    dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
    prompts = dataset["train"]["Prompt"]
    # Initialize vectorizer with the default model
    vectorizer = Vectorizer(model="all-MiniLM-L6-v2")
    # Initialize the search engine
    search_engine = PromptSearchEngine(prompts, vectorizer)
# Define the /search endpoint.
# The route decorator is required; without it the function is never exposed.
# QueryResponse declares (score, prompt) tuples, but this endpoint emits
# objects, so QueryResponse is deliberately NOT passed as response_model
# (it would fail validation on every non-empty result).
@app.post("/search")
def search_prompts(request: QueryRequest):
    """Return the prompts most similar to ``request.query``.

    Args:
        request: Request body carrying the query text and ``n``, the number
            of matches to ask the search engine for (default 5).

    Returns:
        ``{"results": [{"score": float, "prompt": str}, ...]}`` built from the
        ``(score, prompt)`` pairs returned by ``search_engine.most_similar``.
        Returns ``{"results": []}`` while the module-level search engine is
        still None (e.g. before startup_event has run).
    """
    # Only reading the module-level name here, so no ``global`` is needed.
    if search_engine is None:
        return {"results": []}
    # Get the top-n most similar prompts
    similar_prompts = search_engine.most_similar(query=request.query, n=request.n)
    # float() makes the score JSON-serializable even if the engine returns a
    # non-builtin numeric type.
    results = [
        {"score": float(score), "prompt": prompt}
        for score, prompt in similar_prompts
    ]
    return {"results": results}