File size: 9,381 Bytes
4c9881b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# ============================================================
# app/core/rate_limiter.py - Advanced Token Bucket Rate Limiting
# ============================================================

import json
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

from app.ai.config import redis_client
from app.core.error_handling import LojizError

logger = logging.getLogger(__name__)

# ============================================================
# Rate Limit Configuration
# ============================================================

class RateLimitConfig:
    """Static rate-limiting settings.

    ``OPERATION_COSTS`` maps an operation name to the number of credits it
    consumes; ``LIMITS`` maps a scope name ("user", "ip", "global") to its
    credit budget and the window (in seconds) over which that budget refills.
    """

    # Credits charged per operation type.
    OPERATION_COSTS = dict(
        chat=1,           # basic chat
        search=2,         # vector search (expensive)
        list=3,           # create listing (ML validation)
        publish=5,        # publish (database + indexing)
        edit=2,           # edit listing
        upload_image=4,   # image upload (Cloudflare)
    )

    # Per-scope budgets: "credits" per "window_seconds". All windows are one
    # minute; "ip" is more permissive than "user", "global" is system-wide.
    LIMITS = {
        scope_name: {"credits": budget, "window_seconds": 60}
        for scope_name, budget in (
            ("user", 100),
            ("ip", 500),
            ("global", 10000),
        )
    }

    # Intended burst tolerance (50% above the limit).
    # NOTE(review): not read by any code visible in this module.
    BURST_MULTIPLIER = 1.5

    # Intended cleanup cadence / retention for buckets, in seconds.
    # NOTE(review): not read by any code visible in this module.
    CLEANUP_INTERVAL = 60 * 60        # one hour
    MAX_BUCKET_AGE = 24 * 60 * 60     # 24 hours

# ============================================================
# Token Bucket Implementation
# ============================================================

class TokenBucket:
    """In-process token bucket for rate limiting.

    The bucket starts full and refills continuously at ``refill_rate`` tokens
    per second, never exceeding ``capacity``. Elapsed time is measured with
    ``time.monotonic()`` so wall-clock adjustments (NTP steps, DST, manual
    changes) cannot remove tokens or grant extra ones.
    """

    def __init__(self, capacity: int, refill_rate: float):
        """
        Args:
            capacity: Max tokens in bucket (also the initial fill).
            refill_rate: Tokens added per second.
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity
        # Monotonic timestamp of the last refill; only meaningful as a
        # difference against another time.monotonic() reading.
        self.last_refill = time.monotonic()

    def refill(self) -> None:
        """Add tokens for the time elapsed since the last refill, capped at capacity."""
        now = time.monotonic()
        # Clamp so a last_refill in the "future" can never drain the bucket.
        elapsed = max(0.0, now - self.last_refill)

        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now

    def consume(self, tokens: int) -> bool:
        """Try to take *tokens* from the bucket.

        Returns:
            True and deducts the tokens if enough are available after
            refilling; False (leaving the bucket unchanged) otherwise.
        """
        self.refill()

        if self.tokens >= tokens:
            self.tokens -= tokens
            return True

        return False

    def get_available(self) -> int:
        """Return the number of whole tokens available after refilling."""
        self.refill()
        return int(self.tokens)

# ============================================================
# Advanced Rate Limiter
# ============================================================

class AdvancedRateLimiter:
    """Token bucket rate limiter with multiple scopes (user, IP, global).

    Each scope's bucket is stored in Redis under ``rate_limit:<scope_key>`` as
    JSON ``{"tokens", "last_refill", "capacity"}`` with a TTL of twice the
    scope's window. Buckets refill continuously at ``credits / window_seconds``
    tokens per second, capped at ``credits``. Timestamps are wall-clock
    (``time.time()``) because the state is shared through Redis.

    NOTE(review): each bucket is updated with a separate GET followed by a
    SETEX, so concurrent requests for the same scope can read the same state
    and both be admitted. Making this atomic requires a Redis-side script or
    transaction; not attempted here because the client's API is not visible.
    """

    def __init__(self):
        # NOTE(review): neither attribute is read by any method below (bucket
        # state lives in Redis); kept so external code touching them still works.
        self.buckets: Dict[str, TokenBucket] = {}
        self.last_cleanup = time.time()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def is_allowed(
        self,
        user_id: str,
        operation: str,
        ip_address: Optional[str] = None,
    ) -> Tuple[bool, Dict[str, Any]]:
        """
        Check whether *operation* is allowed for *user_id* (and *ip_address*).

        The request is admitted only if every applicable scope (user; ip when
        ``ip_address`` is truthy; global) holds at least the operation's cost
        in tokens. Unknown operations cost 1 credit. The cost is charged to
        every scope only when the request is admitted, so a denied request
        does not drain the scopes that would have passed (e.g. a throttled
        user no longer keeps spending the shared global budget).

        A scope whose Redis state cannot be read fails open: it does not block
        the request and its info is ``{"error": "rate_limit_check_failed"}``.

        Returns:
            (is_allowed, rate_limit_info)
        """
        operation_cost = RateLimitConfig.OPERATION_COSTS.get(operation, 1)

        # scope name -> (scope key, limit config), in check order.
        scopes = {"user": (f"user:{user_id}", RateLimitConfig.LIMITS["user"])}
        if ip_address:
            scopes["ip"] = (f"ip:{ip_address}", RateLimitConfig.LIMITS["ip"])
        scopes["global"] = ("global", RateLimitConfig.LIMITS["global"])

        # Phase 1: read and refill every bucket without charging anything.
        states = {}
        for name, (scope_key, config) in scopes.items():
            states[name] = await self._safe_load(scope_key, config)

        # All scopes must pass; unreadable scopes (None) fail open.
        is_allowed = all(
            state is None or state[0] >= operation_cost
            for state in states.values()
        )

        # Phase 2: persist each bucket, charging the cost only if admitted.
        results = {}
        for name, (scope_key, config) in scopes.items():
            results[name] = await self._commit_scope(
                scope_key, config, states[name], operation_cost, charge=is_allowed
            )

        info = {
            "allowed": is_allowed,
            "operation": operation,
            "cost": operation_cost,
            "user": results["user"],
            "ip": results.get("ip"),  # None when no ip_address was given
            "global": results["global"],
            "timestamp": datetime.utcnow().isoformat(),
        }

        if not is_allowed:
            logger.warning(
                "⚠️ Rate limit exceeded",
                extra={
                    "user_id": user_id,
                    "operation": operation,
                    "ip": ip_address,
                }
            )

        return is_allowed, info

    async def get_usage_stats(self, user_id: str) -> Dict:
        """Return the user's current bucket state, including refill since the last write.

        Redis errors propagate to the caller.
        """
        config = RateLimitConfig.LIMITS["user"]
        bucket_data = await redis_client.get(self._redis_key(f"user:{user_id}"))

        if not bucket_data:
            # No stored bucket: the user is at full capacity.
            return {
                "user_id": user_id,
                "remaining": config["credits"],
                "capacity": config["credits"],
                "reset_in": config["window_seconds"],
            }

        data = json.loads(bucket_data)
        # Apply refill so the figure is current, not the value at last write.
        tokens = self._refill_tokens(data["tokens"], data["last_refill"], time.time(), config)

        return {
            "user_id": user_id,
            "remaining": int(tokens),
            "capacity": data["capacity"],
            "reset_in": config["window_seconds"],
        }

    async def reset_user_limits(self, user_id: str) -> bool:
        """Delete the user's bucket so it starts full again (admin only).

        Returns:
            True on success, False if the Redis delete raised.
        """
        try:
            await redis_client.delete(self._redis_key(f"user:{user_id}"))
            logger.info("✅ Rate limits reset for user: %s", user_id)
            return True
        except Exception as e:
            logger.error("❌ Failed to reset limits: %s", e, exc_info=True)
            return False

    # ------------------------------------------------------------------
    # Single-scope check (kept for backward compatibility)
    # ------------------------------------------------------------------

    async def _check_scope(
        self,
        scope_key: str,
        cost: int,
        config: Dict,
    ) -> Tuple[bool, Dict]:
        """Check and, if allowed, charge a single scope (user/ip/global).

        Fails open ``(True, {"error": ...})`` when the bucket cannot be read.
        """
        state = await self._safe_load(scope_key, config)
        if state is None:
            return True, {"error": "rate_limit_check_failed"}

        allowed = state[0] >= cost
        info = await self._commit_scope(scope_key, config, state, cost, charge=allowed)
        return allowed, info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _redis_key(scope_key: str) -> str:
        """Build the Redis key under which a scope's bucket is stored."""
        return f"rate_limit:{scope_key}"

    @staticmethod
    def _refill_tokens(tokens: float, last_refill: float, now: float, config: Dict) -> float:
        """Return *tokens* topped up for the time elapsed since *last_refill*.

        Refill rate is ``credits / window_seconds`` tokens per second; the
        result is capped at ``credits``. A negative elapsed time (clock skew
        between hosts writing the same key) adds nothing rather than removing
        tokens.
        """
        capacity = config["credits"]
        refill_rate = capacity / config["window_seconds"]
        elapsed = max(0.0, now - last_refill)
        return min(capacity, tokens + elapsed * refill_rate)

    @staticmethod
    def _scope_info(tokens: float, config: Dict) -> Dict:
        """Build the per-scope info dict returned to callers."""
        # NOTE(review): reset_in is the nominal window length, not the exact
        # time until the bucket is full again.
        return {
            "remaining": int(tokens),
            "capacity": config["credits"],
            "reset_in": config["window_seconds"],
        }

    async def _load_tokens(self, scope_key: str, config: Dict) -> Tuple[float, float]:
        """Read a scope's bucket and return ``(refilled_tokens, now)``.

        A missing key is treated as a full bucket. Redis/JSON errors propagate.
        """
        bucket_data = await redis_client.get(self._redis_key(scope_key))
        now = time.time()
        if bucket_data:
            data = json.loads(bucket_data)
            return self._refill_tokens(data["tokens"], data["last_refill"], now, config), now
        return config["credits"], now

    async def _safe_load(self, scope_key: str, config: Dict) -> Optional[Tuple[float, float]]:
        """Like ``_load_tokens`` but logs and returns None on any error (fail open)."""
        try:
            return await self._load_tokens(scope_key, config)
        except Exception as e:
            logger.error("❌ Rate limit check error: %s", e, exc_info=True)
            return None

    async def _commit_scope(
        self,
        scope_key: str,
        config: Dict,
        state: Optional[Tuple[float, float]],
        cost: int,
        charge: bool,
    ) -> Dict:
        """Persist one scope's refilled bucket, deducting *cost* when *charge* is true.

        Returns the scope info dict, or ``{"error": "rate_limit_check_failed"}``
        when the state could not be read (*state* is None) or written.
        """
        if state is None:
            return {"error": "rate_limit_check_failed"}

        tokens, now = state
        passed = tokens >= cost  # this scope's own verdict, for logging
        if charge:
            tokens -= cost

        if passed:
            logger.debug("✅ Rate limit OK: %s (%d tokens left)", scope_key, int(tokens))
        else:
            logger.warning("🚫 Rate limit exceeded: %s", scope_key)

        try:
            await redis_client.setex(
                self._redis_key(scope_key),
                config["window_seconds"] * 2,  # TTL: idle buckets expire
                json.dumps({
                    "tokens": tokens,
                    "last_refill": now,
                    "capacity": config["credits"],
                }),
            )
        except Exception as e:
            logger.error("❌ Rate limit check error: %s", e, exc_info=True)
            return {"error": "rate_limit_check_failed"}

        return self._scope_info(tokens, config)

# ============================================================
# Global Instance
# ============================================================

_rate_limiter = None  # lazily-created shared instance


def get_rate_limiter() -> AdvancedRateLimiter:
    """Return the module-wide rate limiter, constructing it on the first call."""
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = AdvancedRateLimiter()
    return _rate_limiter

# ============================================================
# Exceptions
# ============================================================

class RateLimitExceeded(LojizError):
    """Error raised when a caller has exhausted its rate-limit credits.

    Carries status code 429 and exposes ``retry_after`` (seconds) both as an
    attribute and in the error context.
    """

    def __init__(self, retry_after: int = 60):
        self.retry_after = retry_after
        message = f"Rate limit exceeded. Try again in {retry_after}s"
        details = {"retry_after": retry_after}
        super().__init__(
            message,
            error_code="RATE_LIMIT_EXCEEDED",
            status_code=429,
            recoverable=True,
            context=details,
        )