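"""Startup helpers for the FastAPI app: loading the encrypted search dataset,
mounting the Marimo demo notebook, dev-mode hot reload, and the application lifespan."""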
import io
import os
import asyncio
import functools
import logging
import polars as pl
import numpy as np
from functools import lru_cache
from pathlib import Path
from contextlib import asynccontextmanager
from fastapi import FastAPI
from cryptography.fernet import Fernet
from src.config import config
from src.utils.logging import context_logger
logger = logging.getLogger(__name__)
DEV_MODE = os.getenv("DEV_MODE", "false").lower() in ("1", "true", "yes")
DATA_CACHE_DIR = Path("cache/processed_data")
def _load_from_cache():
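    """Return the cached dataset dict from DATA_CACHE_DIR, or None if the cache is missing or unreadable."""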
    try:
        DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        meta_path = DATA_CACHE_DIR / "meta.parquet"
        item_path = DATA_CACHE_DIR / "item_centroids.npy"
        scale_path = DATA_CACHE_DIR / "scale_centroids.npy"
        if not all(p.exists() for p in [meta_path, item_path, scale_path]):
            logger.info("Cache not found")
            return None
        return {
            'meta': pl.read_parquet(meta_path),
            'item_centroids': np.load(item_path),
            'scale_centroids': np.load(scale_path),
        }
    except Exception as e:
        logger.warning(f"Failed to load from cache: {e}")
        return None
def _save_to_cache(data):
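    """Persist the dataset dict (meta parquet plus centroid arrays) to DATA_CACHE_DIR; failures are logged, not raised."""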
    try:
        DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        data['meta'].write_parquet(DATA_CACHE_DIR / "meta.parquet")
        np.save(DATA_CACHE_DIR / "item_centroids.npy", data['item_centroids'])
        np.save(DATA_CACHE_DIR / "scale_centroids.npy", data['scale_centroids'])
    except Exception as e:
        logger.warning(f"Failed to save to cache: {e}")
async def load_search_data(app: FastAPI):
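    """Load the encrypted search dataset into app.state.data, using the local cache when running in dev mode."""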
    with context_logger("💾 Loading search database"):
        encryption_key = config.data.encryption_key
        if not encryption_key:
            logger.error("DATA_ENCRYPTION_KEY not found")
            app.state.data = None
            return
        try:
            loop = asyncio.get_running_loop()
            if DEV_MODE:
                cached_data = await loop.run_in_executor(None, _load_from_cache)
                if cached_data is not None:
                    app.state.data = cached_data
                    logger.info("✅ Loaded from cache (dev mode)")
                    logger.info(f"`data['meta']` shape: {app.state.data['meta'].shape}")
                    logger.info(f"`data['item_centroids']` shape: {app.state.data['item_centroids'].shape}")
                    logger.info(f"`data['scale_centroids']` shape: {app.state.data['scale_centroids'].shape}")
                    return
            df = await loop.run_in_executor(
                None, _load_dataset_sync, config.data.dataset_path, encryption_key
            )
            app.state.data = {
                'meta': df.drop('item_centroid', 'scale_centroid'),
                'item_centroids': np.vstack(df['item_centroid'].to_list()),
                'scale_centroids': np.vstack(df['scale_centroid'].to_list()),
            }
            if DEV_MODE:
                await loop.run_in_executor(None, _save_to_cache, app.state.data)
                logger.info("✅ Saved to cache (dev mode)")
            logger.info(f"`data['meta']` shape: {app.state.data['meta'].shape}")
            logger.info(f"`data['item_centroids']` shape: {app.state.data['item_centroids'].shape}")
            logger.info(f"`data['scale_centroids']` shape: {app.state.data['scale_centroids'].shape}")
        except Exception as e:
            logger.error(f"Error loading search data: {e}")
            app.state.data = None
def _load_dataset_sync(dataset_path: str, encryption_key: str) -> pl.DataFrame:
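    """Download the dataset, decrypt and unpickle each column value with Fernet, and return a polars DataFrame."""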
    import pickle
    import pandas as pd
    from datasets import load_dataset

    dataset = load_dataset(dataset_path, split="train")
    cipher = Fernet(encryption_key)
    decrypted_rows = []
    for row in dataset:
        decrypted_row = {
            col: pickle.loads(cipher.decrypt(row[col]))
            for col in row.keys()
        }
        decrypted_rows.append(decrypted_row)
    df = pd.DataFrame(decrypted_rows)
    return pl.from_pandas(df)
# @lru_cache()
async def setup_marimo(app: FastAPI):
"""Setup Marimo ASGI app during startup"""
try:
logger.info("π Loading Marimo ASGI app...")
import marimo as mo
notebook_path = Path("notebooks/marimo-demo.py")
if not notebook_path.exists():
logger.warning(f"Notebook not found at {notebook_path}")
app.state.marimo_app = None
return
# Run marimo setup in executor to avoid blocking
loop = asyncio.get_event_loop()
marimo_app = await loop.run_in_executor(
None,
lambda: setup_marimo_sync(notebook_path)
)
if marimo_app:
app.mount("/marimo", marimo_app)
app.state.marimo_app = marimo_app
logger.info("β
Marimo ASGI app mounted at /marimo")
else:
app.state.marimo_app = None
logger.error("β Failed to mount Marimo ASGI app")
except Exception as e:
logger.error(f"β Error setting up marimo: {e}")
app.state.marimo_app = None
# @lru_cache()
def setup_marimo_sync(notebook_path):
"""Synchronous marimo setup function"""
import marimo as mo
try:
marimo_server = mo.create_asgi_app()
marimo_server = marimo_server.with_app(
path="/demo",
root=str(notebook_path.absolute()),
)
return marimo_server.build()
except Exception as e:
logger.error(f"Error creating marimo ASGI app: {e}")
return None
async def setup_hot_reload(app: FastAPI):
"""Setup hot reload for development"""
if not DEV_MODE:
app.state.hot_reload = None
return None
import arel
hot_reload = arel.HotReload(
paths=[
arel.Path("./public"),
arel.Path("./src/templates"),
]
)
app.add_websocket_route("/hot-reload", hot_reload, name="hot-reload")
await hot_reload.startup()
# Store in app state so templates can access it
app.state.hot_reload = hot_reload
logger.info("π₯ Hot reload enabled for development")
return hot_reload
async def cleanup_hot_reload(app: FastAPI):
"""Cleanup hot reload on shutdown"""
hot_reload = getattr(app.state, 'hot_reload', None)
if DEV_MODE and hot_reload:
try:
await hot_reload.shutdown()
logger.info("π₯ Hot reload stopped")
except Exception as e:
logger.error(f"Error stopping hot reload: {e}")
finally:
app.state.hot_reload = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan management"""
try:
await load_search_data(app)
# # Setup marimo (in background to avoid blocking startup)
# asyncio.create_task(setup_marimo(app))
# Setup hot reload for development
await setup_hot_reload(app)
yield
finally:
# Cleanup
await cleanup_hot_reload(app)
def get_hot_reload(app: FastAPI):
"""Get the hot reload instance from app state"""
return getattr(app.state, 'hot_reload', None) |