# synth-net/src/lifespan.py
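"""Application lifespan management for the synth-net FastAPI service.

Loads the encrypted search dataset, optionally mounts the Marimo demo
notebook, and wires up hot reload for development.
"""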
import os
import asyncio
import logging
from functools import lru_cache
from pathlib import Path
from contextlib import asynccontextmanager

import numpy as np
import polars as pl
from cryptography.fernet import Fernet
from fastapi import FastAPI

from src.config import config
from src.utils.logging import context_logger
logger = logging.getLogger(__name__)
# Note: bool() on any non-empty string (including "false") is True, so parse
# the flag explicitly.
DEV_MODE = os.getenv("DEV_MODE", "false").lower() in ("1", "true", "yes")
DATA_CACHE_DIR = Path("cache/processed_data")
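# Dev-mode cache layout: row metadata as a parquet file plus two .npy arrays
# holding the item/scale centroid matrices.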
def _load_from_cache():
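    """Load previously processed artifacts from the dev-mode disk cache."""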
try:
DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
meta_path = DATA_CACHE_DIR / "meta.parquet"
item_path = DATA_CACHE_DIR / "item_centroids.npy"
scale_path = DATA_CACHE_DIR / "scale_centroids.npy"
if not all(p.exists() for p in [meta_path, item_path, scale_path]):
logger.info("Cache not found")
return None
return {
'meta': pl.read_parquet(meta_path),
'item_centroids': np.load(item_path),
'scale_centroids': np.load(scale_path),
}
except Exception as e:
logger.warning(f"Failed to load from cache: {e}")
return None
def _save_to_cache(data):
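    """Persist processed artifacts to the dev-mode disk cache."""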
try:
DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
data['meta'].write_parquet(DATA_CACHE_DIR / "meta.parquet")
np.save(DATA_CACHE_DIR / "item_centroids.npy", data['item_centroids'])
np.save(DATA_CACHE_DIR / "scale_centroids.npy", data['scale_centroids'])
except Exception as e:
logger.warning(f"Failed to save to cache: {e}")
async def load_search_data(app: FastAPI):
    """Decrypt the search dataset and store it on `app.state.data`."""
    with context_logger("πŸ’Ύ Loading search database"):
encryption_key = config.data.encryption_key
if not encryption_key:
logger.error("DATA_ENCRYPTION_KEY not found")
            # The rest of the app reads `app.state.data`, so clear that key.
            app.state.data = None
return
try:
            loop = asyncio.get_running_loop()
if DEV_MODE:
cached_data = await loop.run_in_executor(None, _load_from_cache)
if cached_data is not None:
app.state.data = cached_data
logger.info("βœ… Loaded from cache (dev mode)")
logger.info(f"`data['meta']` shape: {app.state.data['meta'].shape}")
logger.info(f"`data['item_centroids']` shape: {app.state.data['item_centroids'].shape}")
logger.info(f"`data['scale_centroids']` shape: {app.state.data['scale_centroids'].shape}")
return
            df = await loop.run_in_executor(
                None, _load_dataset_sync, config.data.dataset_path, encryption_key
            )
app.state.data = {
'meta': df.drop('item_centroid', 'scale_centroid'),
'item_centroids': np.vstack(df['item_centroid'].to_list()),
'scale_centroids': np.vstack(df['scale_centroid'].to_list()),
}
if DEV_MODE:
await loop.run_in_executor(None, _save_to_cache, app.state.data)
logger.info("βœ… Saved to cache (dev mode)")
logger.info(f"`data['meta']` shape:{ app.state.data['meta'].shape }")
logger.info(f"`data['item_centroids']` shape:{ app.state.data['item_centroids'].shape }")
logger.info(f"`data['scale_centroids']` shape:{ app.state.data['scale_centroids'].shape }")
except Exception as e:
logger.error(f"Error loading search data: {e}")
            app.state.data = None
def _load_dataset_sync(dataset_path: str, encryption_key: str) -> pl.DataFrame:
    """Download the dataset and decrypt every cell (a Fernet-encrypted pickle)."""
    import pickle

    import pandas as pd
    from datasets import load_dataset

    dataset = load_dataset(dataset_path, split="train")
    cipher = Fernet(encryption_key)
    decrypted_rows = [
        {col: pickle.loads(cipher.decrypt(row[col])) for col in row}
        for row in dataset
    ]
    return pl.from_pandas(pd.DataFrame(decrypted_rows))
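# The dataset is assumed to store each cell as a Fernet token wrapping a
# pickled Python object. A hypothetical producer (not part of this service)
# would do the inverse:
#
#     cipher = Fernet(Fernet.generate_key())
#     token = cipher.encrypt(pickle.dumps(value))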
# @lru_cache()
async def setup_marimo(app: FastAPI):
"""Setup Marimo ASGI app during startup"""
try:
logger.info("πŸ”„ Loading Marimo ASGI app...")
import marimo as mo
notebook_path = Path("notebooks/marimo-demo.py")
if not notebook_path.exists():
logger.warning(f"Notebook not found at {notebook_path}")
app.state.marimo_app = None
return
# Run marimo setup in executor to avoid blocking
loop = asyncio.get_event_loop()
marimo_app = await loop.run_in_executor(
None,
lambda: setup_marimo_sync(notebook_path)
)
if marimo_app:
app.mount("/marimo", marimo_app)
app.state.marimo_app = marimo_app
logger.info("βœ… Marimo ASGI app mounted at /marimo")
else:
app.state.marimo_app = None
logger.error("❌ Failed to mount Marimo ASGI app")
except Exception as e:
logger.error(f"❌ Error setting up marimo: {e}")
app.state.marimo_app = None
# @lru_cache()
def setup_marimo_sync(notebook_path):
"""Synchronous marimo setup function"""
import marimo as mo
try:
marimo_server = mo.create_asgi_app()
marimo_server = marimo_server.with_app(
path="/demo",
root=str(notebook_path.absolute()),
)
return marimo_server.build()
except Exception as e:
logger.error(f"Error creating marimo ASGI app: {e}")
return None
async def setup_hot_reload(app: FastAPI):
"""Setup hot reload for development"""
if not DEV_MODE:
app.state.hot_reload = None
return None
    import arel  # dev-only dependency; imported lazily so production doesn't need it
hot_reload = arel.HotReload(
paths=[
arel.Path("./public"),
arel.Path("./src/templates"),
]
)
app.add_websocket_route("/hot-reload", hot_reload, name="hot-reload")
await hot_reload.startup()
# Store in app state so templates can access it
app.state.hot_reload = hot_reload
logger.info("πŸ”₯ Hot reload enabled for development")
return hot_reload
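# Templates opt into live reload by rendering arel's client script, e.g. in a
# Jinja base template (a sketch, assuming `hot_reload` is exposed to the
# template context via `get_hot_reload` below):
#
#     {% if hot_reload %}
#         {{ hot_reload.script(url_for('hot-reload')) | safe }}
#     {% endif %}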
async def cleanup_hot_reload(app: FastAPI):
"""Cleanup hot reload on shutdown"""
hot_reload = getattr(app.state, 'hot_reload', None)
if DEV_MODE and hot_reload:
try:
await hot_reload.shutdown()
logger.info("πŸ”₯ Hot reload stopped")
except Exception as e:
logger.error(f"Error stopping hot reload: {e}")
finally:
app.state.hot_reload = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan management"""
try:
await load_search_data(app)
# # Setup marimo (in background to avoid blocking startup)
# asyncio.create_task(setup_marimo(app))
# Setup hot reload for development
await setup_hot_reload(app)
yield
finally:
# Cleanup
await cleanup_hot_reload(app)
def get_hot_reload(app: FastAPI):
"""Get the hot reload instance from app state"""
return getattr(app.state, 'hot_reload', None)
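# Example wiring (a sketch; the entrypoint module name is an assumption):
#
#     # src/main.py
#     from fastapi import FastAPI
#     from src.lifespan import lifespan
#
#     app = FastAPI(lifespan=lifespan)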