import os
import asyncio
import logging
from pathlib import Path
from contextlib import asynccontextmanager

import numpy as np
import polars as pl
from fastapi import FastAPI
from cryptography.fernet import Fernet

from src.config import config
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)

# Note: bool("false") is truthy, so compare the normalized string instead of casting it.
DEV_MODE = os.getenv("DEV_MODE", "false").lower() in {"1", "true", "yes"}

DATA_CACHE_DIR = Path("cache/processed_data")


def _load_from_cache():
    """Load the processed dataset from the local dev cache, or return None."""
    try:
        DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        meta_path = DATA_CACHE_DIR / "meta.parquet"
        item_path = DATA_CACHE_DIR / "item_centroids.npy"
        scale_path = DATA_CACHE_DIR / "scale_centroids.npy"
        if not all(p.exists() for p in [meta_path, item_path, scale_path]):
            logger.info("Cache not found")
            return None
        return {
            'meta': pl.read_parquet(meta_path),
            'item_centroids': np.load(item_path),
            'scale_centroids': np.load(scale_path),
        }
    except Exception as e:
        logger.warning(f"Failed to load from cache: {e}")
        return None


def _save_to_cache(data):
    """Persist the processed dataset to the local dev cache."""
    try:
        DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        data['meta'].write_parquet(DATA_CACHE_DIR / "meta.parquet")
        np.save(DATA_CACHE_DIR / "item_centroids.npy", data['item_centroids'])
        np.save(DATA_CACHE_DIR / "scale_centroids.npy", data['scale_centroids'])
    except Exception as e:
        logger.warning(f"Failed to save to cache: {e}")


def _log_data_shapes(data):
    """Log the shapes of the metadata table and the centroid matrices."""
    logger.info(f"`data['meta']` shape: {data['meta'].shape}")
    logger.info(f"`data['item_centroids']` shape: {data['item_centroids'].shape}")
    logger.info(f"`data['scale_centroids']` shape: {data['scale_centroids'].shape}")


async def load_search_data(app: FastAPI):
    """Decrypt and load the search dataset into app state at startup."""
    with context_logger("💾 Loading search database"):
        encryption_key = config.data.encryption_key
        if not encryption_key:
            logger.error("DATA_ENCRYPTION_KEY not found")
            app.state.data = None
            return
        try:
            loop = asyncio.get_running_loop()

            if DEV_MODE:
                cached_data = await loop.run_in_executor(None, _load_from_cache)
                if cached_data is not None:
                    app.state.data = cached_data
                    logger.info("✅ Loaded from cache (dev mode)")
                    _log_data_shapes(app.state.data)
                    return

            df = await loop.run_in_executor(
                None, _load_dataset_sync, config.data.dataset_path, encryption_key
            )
            app.state.data = {
                'meta': df.drop('item_centroid', 'scale_centroid'),
                'item_centroids': np.vstack(df['item_centroid'].to_list()),
                'scale_centroids': np.vstack(df['scale_centroid'].to_list()),
            }

            if DEV_MODE:
                await loop.run_in_executor(None, _save_to_cache, app.state.data)
                logger.info("✅ Saved to cache (dev mode)")

            _log_data_shapes(app.state.data)
        except Exception as e:
            logger.error(f"Error loading search data: {e}")
            app.state.data = None


def _load_dataset_sync(dataset_path: str, encryption_key: str) -> pl.DataFrame:
    """Download the encrypted dataset and decrypt every cell with Fernet."""
    import pickle

    import pandas as pd
    from datasets import load_dataset

    dataset = load_dataset(dataset_path, split="train")
    cipher = Fernet(encryption_key)

    decrypted_rows = []
    for row in dataset:
        decrypted_rows.append(
            {col: pickle.loads(cipher.decrypt(row[col])) for col in row}
        )

    # Round-trip through pandas so object columns are inferred the same way as before.
    return pl.from_pandas(pd.DataFrame(decrypted_rows))
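
# Illustrative sketch, not part of the production path: the decryption loop in
# `_load_dataset_sync` implies each dataset cell is a pickled Python object
# encrypted with the same Fernet key. A helper like this (hypothetical name,
# not in the original module) would produce rows that round-trip through it.
def _encrypt_rows_example(rows, encryption_key):
    """Illustrative inverse of the decryption loop in `_load_dataset_sync`.

    `rows` is a list of dicts mapping column names to arbitrary picklable values.
    """
    import pickle

    cipher = Fernet(encryption_key)
    return [
        {col: cipher.encrypt(pickle.dumps(value)) for col, value in row.items()}
        for row in rows
    ]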
Path("notebooks/marimo-demo.py") if not notebook_path.exists(): logger.warning(f"Notebook not found at {notebook_path}") app.state.marimo_app = None return # Run marimo setup in executor to avoid blocking loop = asyncio.get_event_loop() marimo_app = await loop.run_in_executor( None, lambda: setup_marimo_sync(notebook_path) ) if marimo_app: app.mount("/marimo", marimo_app) app.state.marimo_app = marimo_app logger.info("✅ Marimo ASGI app mounted at /marimo") else: app.state.marimo_app = None logger.error("❌ Failed to mount Marimo ASGI app") except Exception as e: logger.error(f"❌ Error setting up marimo: {e}") app.state.marimo_app = None # @lru_cache() def setup_marimo_sync(notebook_path): """Synchronous marimo setup function""" import marimo as mo try: marimo_server = mo.create_asgi_app() marimo_server = marimo_server.with_app( path="/demo", root=str(notebook_path.absolute()), ) return marimo_server.build() except Exception as e: logger.error(f"Error creating marimo ASGI app: {e}") return None async def setup_hot_reload(app: FastAPI): """Setup hot reload for development""" if not DEV_MODE: app.state.hot_reload = None return None import arel hot_reload = arel.HotReload( paths=[ arel.Path("./public"), arel.Path("./src/templates"), ] ) app.add_websocket_route("/hot-reload", hot_reload, name="hot-reload") await hot_reload.startup() # Store in app state so templates can access it app.state.hot_reload = hot_reload logger.info("🔥 Hot reload enabled for development") return hot_reload async def cleanup_hot_reload(app: FastAPI): """Cleanup hot reload on shutdown""" hot_reload = getattr(app.state, 'hot_reload', None) if DEV_MODE and hot_reload: try: await hot_reload.shutdown() logger.info("🔥 Hot reload stopped") except Exception as e: logger.error(f"Error stopping hot reload: {e}") finally: app.state.hot_reload = None @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan management""" try: await load_search_data(app) # # Setup marimo (in background to avoid blocking startup) # asyncio.create_task(setup_marimo(app)) # Setup hot reload for development await setup_hot_reload(app) yield finally: # Cleanup await cleanup_hot_reload(app) def get_hot_reload(app: FastAPI): """Get the hot reload instance from app state""" return getattr(app.state, 'hot_reload', None)