AIDA / app /core /rate_limiter.py
destinyebuka's picture
dora
4c9881b
# ============================================================
# app/core/rate_limiter.py - Advanced Token Bucket Rate Limiting
# ============================================================
import json
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

from app.ai.config import redis_client
from app.core.error_handling import LojizError
# Module logger; used below to record rate-limit decisions and Redis errors.
logger = logging.getLogger(__name__)
# ============================================================
# Rate Limit Configuration
# ============================================================
class RateLimitConfig:
    """Static rate-limiting settings: per-operation costs and per-scope budgets.

    Every operation is charged a number of "credits"; each scope (user, ip,
    global) owns a budget of ``credits`` that refills over ``window_seconds``.
    """

    # Credits charged for each operation type.
    OPERATION_COSTS: Dict[str, int] = {
        "chat": 1,          # basic chat
        "search": 2,        # vector search (expensive)
        "list": 3,          # create listing (ML validation)
        "publish": 5,       # publish (database + indexing)
        "edit": 2,          # edit listing
        "upload_image": 4,  # image upload (Cloudflare)
    }

    # Credit budget per scope and the window (seconds) it refills over.
    LIMITS: Dict[str, Dict[str, int]] = {
        "user": {"credits": 100, "window_seconds": 60},      # 100 credits/min
        "ip": {"credits": 500, "window_seconds": 60},        # 500 credits/min, more permissive
        "global": {"credits": 10_000, "window_seconds": 60}, # 10k credits/min, system-wide
    }

    # NOTE(review): the three settings below are not read anywhere in this
    # file's visible code -- confirm they are consumed elsewhere.
    # Intended temporary spike tolerance: 50% above the limit.
    BURST_MULTIPLIER = 1.5
    # Intended interval between bucket cleanups: one hour, in seconds.
    CLEANUP_INTERVAL = 60 * 60
    # Intended maximum bucket age: 24 hours, in seconds.
    MAX_BUCKET_AGE = 24 * 60 * 60
# ============================================================
# Token Bucket Implementation
# ============================================================
class TokenBucket:
    """In-memory token bucket.

    The bucket starts full with ``capacity`` tokens and regains tokens
    continuously at ``refill_rate`` tokens per second, never exceeding
    ``capacity``. Elapsed time is measured with ``time.monotonic()`` so that
    wall-clock adjustments (NTP, manual changes) can neither drain nor
    overfill the bucket.

    NOTE(review): this class holds no lock; share an instance across threads
    only with external synchronization.
    """

    def __init__(self, capacity: int, refill_rate: float):
        """
        Args:
            capacity: Max tokens in bucket (also the initial fill).
            refill_rate: Tokens added per second.
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens: float = capacity
        self.last_refill = time.monotonic()

    def refill(self) -> None:
        """Add the tokens earned since the last refill, capped at capacity."""
        now = time.monotonic()
        # Clamp defensively: last_refill may have been set externally, and a
        # negative interval must never remove tokens.
        elapsed = max(0.0, now - self.last_refill)
        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now

    def consume(self, tokens: int) -> bool:
        """Try to take ``tokens`` from the bucket.

        Returns:
            True if enough tokens were available (they are deducted),
            False otherwise (the bucket is left unchanged apart from refill).

        Raises:
            ValueError: if ``tokens`` is negative (it would mint tokens and
                push the bucket above capacity).
        """
        if tokens < 0:
            raise ValueError(f"cannot consume a negative number of tokens: {tokens}")
        self.refill()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def get_available(self) -> int:
        """Return the whole number of tokens available right now (after refill)."""
        self.refill()
        return int(self.tokens)
# ============================================================
# Advanced Rate Limiter
# ============================================================
class AdvancedRateLimiter:
    """Token-bucket rate limiter enforcing user, IP and global scopes.

    Each scope's bucket is stored in Redis (``redis_client``) as a JSON object
    under ``rate_limit:<scope_key>`` with a TTL of twice the scope's window.
    Redis/JSON failures fail open: the affected scope is treated as passing.

    NOTE(review): the read-then-write against Redis is not atomic, so two
    concurrent requests for the same scope can spend the same tokens. Closing
    that race would need a Redis-side script or transaction.
    """

    def __init__(self):
        # NOTE(review): neither attribute is read by this class (bucket state
        # lives in Redis); kept so any external references keep working.
        self.buckets: Dict[str, TokenBucket] = {}
        self.last_cleanup = time.time()

    async def is_allowed(
        self,
        user_id: str,
        operation: str,
        ip_address: Optional[str] = None,
    ) -> Tuple[bool, Dict[str, Any]]:
        """Decide whether ``user_id`` may perform ``operation`` now.

        The operation's cost (``RateLimitConfig.OPERATION_COSTS``, default 1)
        must fit in the user bucket, the IP bucket (only when ``ip_address``
        is truthy) and the global bucket. The check is two-phase: every bucket
        is read first, and tokens are deducted from all scopes only if *all*
        of them pass. A request rejected by one scope therefore does not spend
        the budget of the others.

        Args:
            user_id: Key for the per-user bucket.
            operation: Operation name used to look up its credit cost.
            ip_address: Key for the per-IP bucket; the IP scope is skipped
                when this is falsy.

        Returns:
            ``(is_allowed, info)``. ``info`` holds the decision, the operation
            and its cost, and a per-scope summary under ``"user"``, ``"ip"``
            (``None`` when no IP was given) and ``"global"``. It also holds
            a UTC ISO-8601 ``"timestamp"``.
        """
        operation_cost = RateLimitConfig.OPERATION_COSTS.get(operation, 1)

        scopes = [("user", f"user:{user_id}", RateLimitConfig.LIMITS["user"])]
        if ip_address:
            scopes.append(("ip", f"ip:{ip_address}", RateLimitConfig.LIMITS["ip"]))
        scopes.append(("global", "global", RateLimitConfig.LIMITS["global"]))

        # Phase 1: read and refill every bucket without spending anything.
        # A scope whose read fails is recorded as None and fails open.
        loaded: Dict[str, Optional[Tuple[float, float]]] = {}
        for name, scope_key, config in scopes:
            try:
                loaded[name] = await self._load_bucket(scope_key, config)
            except Exception as e:
                logger.error(f"❌ Rate limit check error: {e}")
                loaded[name] = None

        is_allowed = all(
            state is None or state[0] >= operation_cost
            for state in loaded.values()
        )

        # Phase 2: commit. Tokens are spent only when every scope passed; the
        # refilled state is written back either way so last_refill advances.
        scope_info: Dict[str, Dict[str, Any]] = {}
        for name, scope_key, config in scopes:
            state = loaded[name]
            if state is None:
                scope_info[name] = {"error": "rate_limit_check_failed"}
                continue
            tokens, now = state
            if is_allowed:
                tokens -= operation_cost
                logger.debug(f"✅ Rate limit OK: {scope_key} ({int(tokens)} tokens left)")
            elif tokens < operation_cost:
                logger.warning(f"🚫 Rate limit exceeded: {scope_key}")
            try:
                await self._store_bucket(scope_key, config, tokens, now)
            except Exception as e:
                # The decision is already made; a failed write only loses this
                # request's bookkeeping and must not flip a denial into an allow.
                logger.error(f"❌ Rate limit check error: {e}")
            scope_info[name] = self._scope_info(tokens, config)

        info = {
            "allowed": is_allowed,
            "operation": operation,
            "cost": operation_cost,
            "user": scope_info["user"],
            "ip": scope_info.get("ip"),
            "global": scope_info["global"],
            # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
            # datetime.now(timezone.utc) would append "+00:00" to this string.
            "timestamp": datetime.utcnow().isoformat(),
        }

        if not is_allowed:
            logger.warning(
                "⚠️ Rate limit exceeded",
                extra={
                    "user_id": user_id,
                    "operation": operation,
                    "ip": ip_address,
                },
            )

        return is_allowed, info

    async def _check_scope(
        self,
        scope_key: str,
        cost: int,
        config: Dict,
    ) -> Tuple[bool, Dict]:
        """Check a single scope and, if it has room, spend ``cost`` tokens.

        No longer used by ``is_allowed``. It is retained with its original
        single-scope contract in case other code calls it. It fails open on
        any error, returning ``(True, {"error": "rate_limit_check_failed"})``.
        """
        try:
            tokens, now = await self._load_bucket(scope_key, config)
            allowed = tokens >= cost
            if allowed:
                tokens -= cost
                logger.debug(f"✅ Rate limit OK: {scope_key} ({int(tokens)} tokens left)")
            else:
                logger.warning(f"🚫 Rate limit exceeded: {scope_key}")
            await self._store_bucket(scope_key, config, tokens, now)
            return allowed, self._scope_info(tokens, config)
        except Exception as e:
            logger.error(f"❌ Rate limit check error: {e}")
            # Fail open (allow) on error
            return True, {"error": "rate_limit_check_failed"}

    async def _load_bucket(self, scope_key: str, config: Dict) -> Tuple[float, float]:
        """Read a scope's bucket from Redis and apply time-based refill.

        A missing key is treated as a fresh, full bucket. Redis and JSON
        errors propagate to the caller.

        Returns:
            ``(tokens, now)``: tokens available now (capped at
            ``config["credits"]``) and the timestamp the refill was computed at.
        """
        bucket_data = await redis_client.get(f"rate_limit:{scope_key}")
        now = time.time()
        if bucket_data:
            data = json.loads(bucket_data)
            tokens = data["tokens"]
            last_refill = data["last_refill"]
        else:
            tokens = config["credits"]
            last_refill = now
        # Wall-clock time is required here because the timestamp is shared
        # between processes through Redis. Clamp negative intervals (clock
        # skew between hosts) so a lagging clock cannot drain the bucket.
        elapsed = max(0.0, now - last_refill)
        refill_rate = config["credits"] / config["window_seconds"]
        tokens = min(config["credits"], tokens + elapsed * refill_rate)
        return tokens, now

    async def _store_bucket(
        self,
        scope_key: str,
        config: Dict,
        tokens: float,
        now: float,
    ) -> None:
        """Persist a scope's bucket to Redis with a TTL of two windows."""
        await redis_client.setex(
            f"rate_limit:{scope_key}",
            config["window_seconds"] * 2,  # TTL
            json.dumps({
                "tokens": tokens,
                "last_refill": now,
                "capacity": config["credits"],
            }),
        )

    @staticmethod
    def _scope_info(tokens: float, config: Dict) -> Dict:
        """Build the per-scope summary returned to callers."""
        return {
            "remaining": int(tokens),
            "capacity": config["credits"],
            "reset_in": config["window_seconds"],
        }

    async def get_usage_stats(self, user_id: str) -> Dict:
        """Report the user's current budget, including refill since the last write.

        A user with no stored bucket is reported at full capacity. Redis and
        JSON errors propagate to the caller.
        """
        config = RateLimitConfig.LIMITS["user"]
        tokens, _ = await self._load_bucket(f"user:{user_id}", config)
        return {
            "user_id": user_id,
            "remaining": int(tokens),
            "capacity": config["credits"],
            "reset_in": config["window_seconds"],
        }

    async def reset_user_limits(self, user_id: str) -> bool:
        """Delete the user's stored bucket (admin only).

        Returns:
            True on success, False if the Redis delete raised.
        """
        try:
            await redis_client.delete(f"rate_limit:user:{user_id}")
            logger.info(f"✅ Rate limits reset for user: {user_id}")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to reset limits: {e}")
            return False
# ============================================================
# Global Instance
# ============================================================
# Lazily-created process-wide limiter; populated on first get_rate_limiter().
_rate_limiter = None


def get_rate_limiter() -> AdvancedRateLimiter:
    """Return the shared AdvancedRateLimiter, creating it on first call."""
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = AdvancedRateLimiter()
    return _rate_limiter
# ============================================================
# Exceptions
# ============================================================
class RateLimitExceeded(LojizError):
    """Error signalling that a caller has exhausted its rate-limit budget.

    Sets ``retry_after`` (seconds) on the instance and passes it to
    ``LojizError`` inside ``context``. The error code is
    ``RATE_LIMIT_EXCEEDED``, the HTTP status is 429, and the error is
    flagged recoverable.
    """

    def __init__(self, retry_after: int = 60):
        self.retry_after = retry_after
        message = f"Rate limit exceeded. Try again in {retry_after}s"
        super().__init__(
            message,
            error_code="RATE_LIMIT_EXCEEDED",
            status_code=429,
            recoverable=True,
            context={"retry_after": retry_after},
        )