import asyncio
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path

import numpy as np
import polars as pl
from cryptography.fernet import Fernet
from fastapi import FastAPI

from src.config import config
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)

# Note: bool() on any non-empty string is True (including "false"),
# so the flag must be compared explicitly.
DEV_MODE = os.getenv("DEV_MODE", "false").lower() in ("1", "true", "yes")
DATA_CACHE_DIR = Path("cache/processed_data")


def _load_from_cache():
    """Load processed search data from the local dev cache, or return None if absent."""
    try:
        DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)

        meta_path = DATA_CACHE_DIR / "meta.parquet"
        item_path = DATA_CACHE_DIR / "item_centroids.npy"
        scale_path = DATA_CACHE_DIR / "scale_centroids.npy"

        if not all(p.exists() for p in [meta_path, item_path, scale_path]):
            logger.info("Cache not found")
            return None

        return {
            'meta': pl.read_parquet(meta_path),
            'item_centroids': np.load(item_path),
            'scale_centroids': np.load(scale_path),
        }
    except Exception as e:
        logger.warning(f"Failed to load from cache: {e}")
        return None


def _save_to_cache(data):
    """Persist processed search data to the local cache for faster dev restarts."""
    try:
        DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)

        data['meta'].write_parquet(DATA_CACHE_DIR / "meta.parquet")
        np.save(DATA_CACHE_DIR / "item_centroids.npy", data['item_centroids'])
        np.save(DATA_CACHE_DIR / "scale_centroids.npy", data['scale_centroids'])
    except Exception as e:
        logger.warning(f"Failed to save to cache: {e}")


async def load_search_data(app: FastAPI):
    """Load, decrypt, and stage the search dataset on `app.state.data`."""
    with context_logger("💾 Loading search database"):

        encryption_key = config.data.encryption_key
        if not encryption_key:
            logger.error("DATA_ENCRYPTION_KEY not found")
            # Use the same attribute the success path sets.
            app.state.data = None
            return

        try:
            loop = asyncio.get_running_loop()

            if DEV_MODE:
                cached_data = await loop.run_in_executor(None, _load_from_cache)
                if cached_data is not None:
                    app.state.data = cached_data
                    logger.info("✅ Loaded from cache (dev mode)")
                    logger.info(f"`data['meta']` shape: {app.state.data['meta'].shape}")
                    logger.info(f"`data['item_centroids']` shape: {app.state.data['item_centroids'].shape}")
                    logger.info(f"`data['scale_centroids']` shape: {app.state.data['scale_centroids'].shape}")
                    return

            df = await loop.run_in_executor(
                None, _load_dataset_sync, config.data.dataset_path, config.data.encryption_key
            )

            app.state.data = {
                'meta': df.drop('item_centroid', 'scale_centroid'),
                'item_centroids': np.vstack(df['item_centroid'].to_list()),
                'scale_centroids': np.vstack(df['scale_centroid'].to_list()),
            }

            if DEV_MODE:
                await loop.run_in_executor(None, _save_to_cache, app.state.data)
                logger.info("✅ Saved to cache (dev mode)")

            logger.info(f"`data['meta']` shape: {app.state.data['meta'].shape}")
            logger.info(f"`data['item_centroids']` shape: {app.state.data['item_centroids'].shape}")
            logger.info(f"`data['scale_centroids']` shape: {app.state.data['scale_centroids'].shape}")

        except Exception as e:
            logger.error(f"Error loading search data: {e}")
            app.state.data = None


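# Query sketch (hypothetical, not code from this repo) showing why the centroids
# are staged as dense matrices: a cosine-similarity scan against a unit-norm
# query vector `q` could look like
#   sims = data['item_centroids'] @ q / np.linalg.norm(data['item_centroids'], axis=1)
#   hit = data['meta'].row(int(np.argmax(sims)))

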
def _load_dataset_sync(dataset_path: str, encryption_key: str) -> pl.DataFrame:
    """Download the dataset and decrypt each cell (Fernet-encrypted pickles)."""
    import pickle

    import pandas as pd
    from datasets import load_dataset

    dataset = load_dataset(dataset_path, split="train")

    cipher = Fernet(encryption_key)
    decrypted_rows = []
    for row in dataset:
        decrypted_row = {
            col: pickle.loads(cipher.decrypt(row[col]))
            for col in row.keys()
        }
        decrypted_rows.append(decrypted_row)

    df = pd.DataFrame(decrypted_rows)
    return pl.from_pandas(df)


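# For reference, the decryption above implies the dataset was produced by the
# inverse transform (a sketch under that assumption, not code from this repo):
# each cell pickled, then Fernet-encrypted with the same key, e.g.
#   cipher = Fernet(encryption_key)
#   encrypted_row = {col: cipher.encrypt(pickle.dumps(val)) for col, val in row.items()}

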
async def setup_marimo(app: FastAPI):
    """Setup Marimo ASGI app during startup"""
    try:
        logger.info("Loading Marimo ASGI app...")

        notebook_path = Path("notebooks/marimo-demo.py")

        if not notebook_path.exists():
            logger.warning(f"Notebook not found at {notebook_path}")
            app.state.marimo_app = None
            return

        loop = asyncio.get_running_loop()
        marimo_app = await loop.run_in_executor(None, setup_marimo_sync, notebook_path)

        if marimo_app:
            app.mount("/marimo", marimo_app)
            app.state.marimo_app = marimo_app
            logger.info("✅ Marimo ASGI app mounted at /marimo")
        else:
            app.state.marimo_app = None
            logger.error("❌ Failed to mount Marimo ASGI app")

    except Exception as e:
        logger.error(f"❌ Error setting up marimo: {e}")
        app.state.marimo_app = None


def setup_marimo_sync(notebook_path):
    """Synchronous marimo setup function"""
    import marimo as mo

    try:
        marimo_server = mo.create_asgi_app()
        marimo_server = marimo_server.with_app(
            path="/demo",
            root=str(notebook_path.absolute()),
        )
        return marimo_server.build()
    except Exception as e:
        logger.error(f"Error creating marimo ASGI app: {e}")
        return None


async def setup_hot_reload(app: FastAPI):
    """Setup hot reload for development"""
    if not DEV_MODE:
        app.state.hot_reload = None
        return None

    import arel

    hot_reload = arel.HotReload(
        paths=[
            arel.Path("./public"),
            arel.Path("./src/templates"),
        ]
    )
    app.add_websocket_route("/hot-reload", hot_reload, name="hot-reload")
    await hot_reload.startup()

    app.state.hot_reload = hot_reload

    logger.info("🔥 Hot reload enabled for development")
    return hot_reload


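# Client-side wiring sketch: arel's documented pattern is to expose the
# `hot_reload` instance to the template context (see get_hot_reload below) and
# include its script tag in the base Jinja template, e.g.
#   {% if hot_reload %}
#     {{ hot_reload.script(url_for('hot-reload')) | safe }}
#   {% endif %}

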
async def cleanup_hot_reload(app: FastAPI):
    """Cleanup hot reload on shutdown"""
    hot_reload = getattr(app.state, 'hot_reload', None)
    if DEV_MODE and hot_reload:
        try:
            await hot_reload.shutdown()
            logger.info("🔥 Hot reload stopped")
        except Exception as e:
            logger.error(f"Error stopping hot reload: {e}")
        finally:
            app.state.hot_reload = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan management"""
    try:
        await load_search_data(app)
        await setup_marimo(app)
        await setup_hot_reload(app)
        yield
    finally:
        await cleanup_hot_reload(app)


def get_hot_reload(app: FastAPI):
    """Get the hot reload instance from app state"""
    return getattr(app.state, 'hot_reload', None)
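

if __name__ == "__main__":
    # Minimal local run sketch (assumes uvicorn is installed; the host and port
    # are illustrative, not part of this codebase).
    import uvicorn

    app = FastAPI(lifespan=lifespan)
    uvicorn.run(app, host="127.0.0.1", port=8000)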