File size: 7,369 Bytes
6ca4b94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import io
import os
import asyncio
import functools
import logging
import polars as pl
import numpy as np
from functools import lru_cache
from pathlib import Path
from contextlib import asynccontextmanager
from fastapi import FastAPI
from pathlib import Path
from cryptography.fernet import Fernet

from src.config import config
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)

# BUG FIX: the previous `bool(os.getenv("DEV_MODE", "false").lower())` was
# always True — any non-empty string (including "false") is truthy, so dev
# mode could never be switched off.  Compare against accepted "on" values.
DEV_MODE = os.getenv("DEV_MODE", "false").lower() in ("1", "true", "yes", "on")
# Directory for the dev-mode parquet/npy cache of the processed search data.
DATA_CACHE_DIR = Path("cache/processed_data")

def _load_from_cache():
    """Return the cached search-data dict, or None when the cache is
    missing or unreadable.

    The dict has keys ``meta`` (polars DataFrame), ``item_centroids`` and
    ``scale_centroids`` (numpy arrays).  Best-effort: all failures are
    logged and reported as a cache miss (None).
    """
    try:
        DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)

        paths = {
            'meta': DATA_CACHE_DIR / "meta.parquet",
            'item_centroids': DATA_CACHE_DIR / "item_centroids.npy",
            'scale_centroids': DATA_CACHE_DIR / "scale_centroids.npy",
        }

        # All three artifacts must be present for the cache to be usable.
        if any(not path.exists() for path in paths.values()):
            logger.info("Cache not found")
            return None

        return {
            'meta': pl.read_parquet(paths['meta']),
            'item_centroids': np.load(paths['item_centroids']),
            'scale_centroids': np.load(paths['scale_centroids']),
        }
    except Exception as e:
        logger.warning(f"Failed to load from cache: {e}")
        return None

def _save_to_cache(data):
    """Persist the search-data dict into DATA_CACHE_DIR.

    Best-effort: failures are logged as warnings and never raised, since
    the cache is a dev-mode convenience rather than required state.
    """
    try:
        DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)

        data['meta'].write_parquet(DATA_CACHE_DIR / "meta.parquet")
        for key in ('item_centroids', 'scale_centroids'):
            np.save(DATA_CACHE_DIR / f"{key}.npy", data[key])
    except Exception as e:
        logger.warning(f"Failed to save to cache: {e}")


async def load_search_data(app: FastAPI):
    """Load the encrypted search dataset into ``app.state.data``.

    On success ``app.state.data`` is a dict with keys ``meta`` (polars
    DataFrame) and ``item_centroids`` / ``scale_centroids`` (stacked numpy
    arrays).  On any failure ``app.state.data`` is set to None so callers
    can detect missing data.  In DEV_MODE a local parquet/npy cache is
    consulted first and populated after a fresh load, skipping the
    expensive download/decrypt step on subsequent startups.
    """
    with context_logger("πŸ’Ύ Loading search database"):

        encryption_key = config.data.encryption_key
        if not encryption_key:
            logger.error("DATA_ENCRYPTION_KEY not found")
            # BUG FIX: the failure paths previously set only
            # `app.state.search_data`, while the success path sets
            # `app.state.data` — readers of `app.state.data` would hit an
            # AttributeError on failure.  Set both so existing checks of
            # either attribute keep working.
            app.state.data = None
            app.state.search_data = None
            return

        try:
            # get_running_loop() is the supported API inside a coroutine;
            # get_event_loop() is deprecated for this use since 3.10.
            loop = asyncio.get_running_loop()

            if DEV_MODE:
                cached_data = await loop.run_in_executor(None, _load_from_cache)
                if cached_data is not None:
                    app.state.data = cached_data
                    logger.info("βœ… Loaded from cache (dev mode)")
                    _log_data_shapes(app.state.data)
                    return

            # Heavy download + decrypt work runs in a thread so the event
            # loop stays responsive during startup.
            df = await loop.run_in_executor(
                None, _load_dataset_sync, config.data.dataset_path, encryption_key
            )

            app.state.data = {
                'meta': df.drop('item_centroid', 'scale_centroid'),
                'item_centroids': np.vstack(df['item_centroid'].to_list()),
                'scale_centroids': np.vstack(df['scale_centroid'].to_list()),
            }

            if DEV_MODE:
                await loop.run_in_executor(None, _save_to_cache, app.state.data)
                logger.info("βœ… Saved to cache (dev mode)")

            _log_data_shapes(app.state.data)

        except Exception as e:
            logger.error(f"Error loading search data: {e}")
            # Same consistency fix as above: clear both attribute names.
            app.state.data = None
            app.state.search_data = None


def _log_data_shapes(data):
    """Log the shape of each loaded component for quick sanity checks."""
    for key in ('meta', 'item_centroids', 'scale_centroids'):
        logger.info(f"`data['{key}']` shape: {data[key].shape}")

def _load_dataset_sync(dataset_path: str, encryption_key: str) -> pl.DataFrame:
    """Download, decrypt and deserialize the dataset into a polars frame.

    Every cell in the Hugging Face dataset is a Fernet-encrypted pickle;
    each value is decrypted and unpickled before the rows are assembled.
    Blocking — intended to be run in an executor.
    """
    import pickle
    import pandas as pd
    from datasets import load_dataset

    cipher = Fernet(encryption_key)
    dataset = load_dataset(dataset_path, split="train")

    # NOTE(review): pickle.loads on decrypted payloads is only safe if the
    # dataset is authored in-house — confirm this is never pointed at an
    # untrusted dataset path.
    decrypted_rows = [
        {col: pickle.loads(cipher.decrypt(value)) for col, value in row.items()}
        for row in dataset
    ]

    return pl.from_pandas(pd.DataFrame(decrypted_rows))

# @lru_cache()
async def setup_marimo(app: FastAPI):
    """Build the Marimo notebook as an ASGI sub-app and mount it at /marimo.

    Sets ``app.state.marimo_app`` to the mounted ASGI app on success and to
    None on any failure (missing notebook, import/build error).  Never
    raises — setup problems are logged and startup continues.
    """
    try:
        logger.info("πŸ”„ Loading Marimo ASGI app...")

        # FIX: removed an unused `import marimo as mo` here — the actual
        # import happens inside setup_marimo_sync, which runs in the worker.
        notebook_path = Path("notebooks/marimo-demo.py")

        if not notebook_path.exists():
            logger.warning(f"Notebook not found at {notebook_path}")
            app.state.marimo_app = None
            return

        # Run marimo setup in an executor to avoid blocking the event loop;
        # arguments can be passed directly, no lambda wrapper needed.
        loop = asyncio.get_running_loop()
        marimo_app = await loop.run_in_executor(None, setup_marimo_sync, notebook_path)

        if marimo_app:
            app.mount("/marimo", marimo_app)
            app.state.marimo_app = marimo_app
            logger.info("βœ… Marimo ASGI app mounted at /marimo")
        else:
            app.state.marimo_app = None
            logger.error("❌ Failed to mount Marimo ASGI app")

    except Exception as e:
        logger.error(f"❌ Error setting up marimo: {e}")
        app.state.marimo_app = None

# @lru_cache()
def setup_marimo_sync(notebook_path):
    """Build and return the Marimo ASGI app for *notebook_path*.

    Returns None if the app cannot be created; the error is logged.
    Blocking — intended to be run in an executor.
    """
    import marimo as mo

    try:
        server = mo.create_asgi_app().with_app(
            path="/demo",
            root=str(notebook_path.absolute()),
        )
        return server.build()
    except Exception as e:
        logger.error(f"Error creating marimo ASGI app: {e}")
        return None

async def setup_hot_reload(app: FastAPI):
    """Enable arel-based browser hot reload when running in DEV_MODE.

    Returns the started HotReload instance (also stored on
    ``app.state.hot_reload`` so templates can reach it), or None outside
    dev mode.
    """
    if not DEV_MODE:
        app.state.hot_reload = None
        return None

    import arel

    watched_paths = [
        arel.Path("./public"),
        arel.Path("./src/templates"),
    ]
    hot_reload = arel.HotReload(paths=watched_paths)

    app.add_websocket_route("/hot-reload", hot_reload, name="hot-reload")
    await hot_reload.startup()

    # Templates read this off app.state to inject the reload websocket script.
    app.state.hot_reload = hot_reload

    logger.info("πŸ”₯ Hot reload enabled for development")
    return hot_reload

async def cleanup_hot_reload(app: FastAPI):
    """Shut down the dev-mode hot reloader, if one was started.

    Safe to call unconditionally: does nothing unless DEV_MODE is on and a
    HotReload instance is stored on ``app.state``.
    """
    hot_reload = getattr(app.state, 'hot_reload', None)
    if not (DEV_MODE and hot_reload):
        return
    try:
        await hot_reload.shutdown()
        logger.info("πŸ”₯ Hot reload stopped")
    except Exception as e:
        logger.error(f"Error stopping hot reload: {e}")
    finally:
        # Clear the reference so templates stop injecting the reload script.
        app.state.hot_reload = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan context: startup before ``yield``, shutdown after.

    Startup loads the search data and (in dev mode) enables hot reload;
    the ``finally`` guarantees hot-reload cleanup runs even if startup or
    the app itself fails.
    """
    
    try:

        # Load the search dataset into app.state before serving requests.
        await load_search_data(app)
        
        # # Setup marimo (in background to avoid blocking startup)
        # asyncio.create_task(setup_marimo(app))
        
        # Setup hot reload for development (no-op outside DEV_MODE)
        await setup_hot_reload(app)
        
        yield
        
    finally:
        # Cleanup
        await cleanup_hot_reload(app)

def get_hot_reload(app: FastAPI):
    """Return the HotReload instance stored on ``app.state`` (None if absent)."""
    state = app.state
    return getattr(state, 'hot_reload', None)