| """ | |
| MrrrMe Backend - WebSocket Handler (PRODUCTION - FIXED DISCONNECTION HANDLING) | |
| """ | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from starlette.websockets import WebSocketState | |
| import asyncio | |
| import base64 | |
| import numpy as np | |
| import cv2 | |
| import io | |
| from PIL import Image | |
| import requests | |
| # Import the module for dynamic model access | |
| from . import models as models_module | |
| from .session.manager import validate_token, save_message, load_user_history | |
| from .session.summary import generate_session_summary | |
| from .auth.database import get_db_connection | |
| from .utils.helpers import get_avatar_api_url | |
| from .config import GREETINGS | |
| from .processing.speech import process_speech_end | |
| AVATAR_API = get_avatar_api_url() | |
async def websocket_endpoint(websocket: WebSocket):
    """Main WebSocket endpoint handler"""
    print("\n" + "="*80, flush=True)
    print("[WebSocket] 🔌 NEW CONNECTION", flush=True)
    print("="*80, flush=True)

    try:
        await websocket.accept()
        print("[WebSocket] ✅ Connection accepted", flush=True)
    except Exception as e:
        print(f"[WebSocket] ❌ Failed to accept: {e}", flush=True)
        return

    # Session variables
    session_data = None
    user_summary = None
    session_id = None
    user_id = None
    username = None
    audio_buffer = []
    user_preferences = {"voice": "female", "language": "en", "personality": "therapist"}

    # Message counters (for stats, not logging)
    message_count = 0
    video_count = 0
    audio_count = 0
    try:
        # ===== AUTHENTICATION =====
        print("[WebSocket] ⏳ Waiting for auth message...", flush=True)
        try:
            auth_msg = await asyncio.wait_for(websocket.receive_json(), timeout=10.0)
        except asyncio.TimeoutError:
            print("[WebSocket] ❌ Auth timeout", flush=True)
            await websocket.close(code=1008, reason="Auth timeout")
            return

        if auth_msg.get("type") != "auth":
            print(f"[WebSocket] ❌ Wrong type: {auth_msg.get('type')}", flush=True)
            await websocket.send_json({"type": "error", "message": "Auth required"})
            await websocket.close(code=1008)
            return

        token = auth_msg.get("token")
        if not token:
            print("[WebSocket] ❌ No token", flush=True)
            await websocket.send_json({"type": "error", "message": "No token"})
            await websocket.close(code=1008)
            return

        # Validate
        session_data = validate_token(token)
        if not session_data:
            print("[WebSocket] ❌ Invalid token", flush=True)
            await websocket.send_json({"type": "error", "message": "Invalid session"})
            await websocket.close(code=1008)
            return

        session_id = session_data['session_id']
        user_id = session_data['user_id']
        username = session_data['username']
        print(f"[WebSocket] ✅ Authenticated: {username}", flush=True)

        # Get summary
        try:
            conn = get_db_connection()
            cursor = conn.cursor()
            cursor.execute("SELECT summary_text FROM user_summaries WHERE user_id = ?", (user_id,))
            summary_row = cursor.fetchone()
            user_summary = summary_row[0] if summary_row else None
            conn.close()
        except Exception as e:
            print(f"[WebSocket] ⚠️ Summary error: {e}", flush=True)
            user_summary = None

        # Send auth confirmation
        await websocket.send_json({
            "type": "authenticated",
            "username": username,
            "summary": user_summary
        })
        # Load history
        current_models = models_module.get_models()
        if current_models['llm_generator']:
            current_models['llm_generator'].clear_history()
            user_history = load_user_history(user_id, limit=10)
            for role, content in user_history:
                current_models['llm_generator'].conversation_history.append({
                    "role": role,
                    "content": content
                })
            if user_history:
                print(f"[WebSocket] 📜 Loaded {len(user_history)} messages", flush=True)

        # Check models ready
        if not models_module.model_state.ready:
            print("[WebSocket] ⏳ Waiting for models...", flush=True)
            await websocket.send_json({
                "type": "status",
                "message": "Models loading..."
            })
            for i in range(900):
                if models_module.model_state.ready:
                    await websocket.send_json({
                        "type": "status",
                        "message": "Models ready!"
                    })
                    print(f"[WebSocket] ✅ Models ready after {i}s", flush=True)
                    break
                await asyncio.sleep(1)

            if not models_module.model_state.ready:
                print("[WebSocket] ❌ Models timeout", flush=True)
                await websocket.send_json({"type": "error", "message": "Models timeout"})
                await websocket.close(code=1011)
                return
| print(f"[WebSocket] β Ready - starting message loop for {username}", flush=True) | |
| # ===== MESSAGE LOOP ===== | |
| while True: | |
| # Check if WebSocket is still connected | |
| if websocket.client_state != WebSocketState.CONNECTED: | |
| print(f"[WebSocket] π {username} disconnected (client closed)", flush=True) | |
| break | |
| try: | |
| data = await websocket.receive_json() | |
| message_count += 1 | |
| msg_type = data.get("type") | |
| # Get latest models dynamically | |
| models = models_module.get_models() | |
| face_processor = models['face_processor'] | |
| voice_worker = models['voice_worker'] | |
| whisper_worker = models['whisper_worker'] | |
| # ===== PREFERENCES ===== | |
| if msg_type == "preferences": | |
| if "voice" in data: | |
| user_preferences["voice"] = data.get("voice", "female") | |
| if "language" in data: | |
| user_preferences["language"] = data.get("language", "en") | |
| if "personality" in data: | |
| user_preferences["personality"] = data.get("personality", "therapist") | |
| print(f"[Preferences] β {username}: {user_preferences}", flush=True) | |
| if "personality" in data and models['llm_generator']: | |
| new_personality = data.get("personality", "therapist") | |
| models['llm_generator'].set_personality(new_personality) | |
| print(f"[LLM] π Personality changed to: {new_personality}", flush=True) | |
| continue | |
| # ===== GREETING ===== | |
| elif msg_type == "request_greeting": | |
| print(f"[Greeting] π€ Request from {username}", flush=True) | |
| try: | |
| lang = user_preferences.get("language", "en") | |
| if user_summary: | |
| greeting_text = GREETINGS[lang]["returning"].format(username=username) | |
| else: | |
| greeting_text = GREETINGS[lang]["new"].format(username=username) | |
| # Try avatar TTS | |
| audio_url = None | |
| visemes = None | |
| try: | |
| voice_pref = user_preferences.get("voice", "female") | |
| lang_pref = user_preferences.get("language", "en") | |
| avatar_response = requests.post( | |
| f"{AVATAR_API}/speak", | |
| data={ | |
| "text": greeting_text, | |
| "voice": voice_pref, | |
| "language": lang_pref | |
| }, | |
| timeout=10 | |
| ) | |
| if avatar_response.status_code == 200: | |
| avatar_data = avatar_response.json() | |
| audio_url = avatar_data.get("audio_url") | |
| visemes = avatar_data.get("visemes") | |
| print(f"[Greeting] β TTS generated", flush=True) | |
| else: | |
| print(f"[Greeting] β οΈ TTS failed: {avatar_response.status_code}", flush=True) | |
| except Exception as tts_err: | |
| print(f"[Greeting] β οΈ TTS error: {tts_err}", flush=True) | |
| response_data = { | |
| "type": "llm_response", | |
| "text": greeting_text, | |
| "emotion": "Neutral", | |
| "intensity": 0.5, | |
| "is_greeting": True | |
| } | |
| if audio_url and visemes: | |
| response_data["audio_url"] = audio_url | |
| response_data["visemes"] = visemes | |
| else: | |
| response_data["text_only"] = True | |
| await websocket.send_json(response_data) | |
| print(f"[Greeting] β Sent to {username}", flush=True) | |
| save_message(session_id, "assistant", greeting_text, "Neutral") | |
| except Exception as err: | |
| print(f"[Greeting] β Error: {err}", flush=True) | |
                # ===== VIDEO FRAME =====
                elif msg_type == "video_frame":
                    video_count += 1
                    try:
                        img_data = base64.b64decode(data["frame"].split(",")[1])
                        img = Image.open(io.BytesIO(img_data))
                        frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

                        processed_frame, result = face_processor.process_frame(frame)
                        face_emotion = face_processor.get_last_emotion() or "Neutral"
                        face_confidence = face_processor.get_last_confidence() or 0.0
                        face_probs = face_processor.get_last_probs()
                        face_quality = getattr(face_processor, 'get_last_quality', lambda: 0.5)()

                        await websocket.send_json({
                            "type": "face_emotion",
                            "emotion": face_emotion,
                            "confidence": face_confidence,
                            # Guard against a frame with no detected face (no probabilities yet)
                            "probabilities": face_probs.tolist() if face_probs is not None else None,
                            "quality": face_quality
                        })
                    except Exception as e:
                        if video_count % 100 == 0:
                            print(f"[Video] ⚠️ Error (frame {video_count}): {e}", flush=True)
                # ===== AUDIO CHUNK (UPDATED FIX) =====
                elif msg_type == "audio_chunk":
                    audio_count += 1
                    try:
                        audio_data = base64.b64decode(data["audio"])

                        # Validate buffer size
                        if len(audio_data) % 2 != 0 or len(audio_data) == 0:
                            continue

                        # Convert to float32 array
                        audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0

                        # 1. Send to Whisper
                        if whisper_worker:
                            whisper_worker.add_audio(audio_array)

                        # 2. ✅ FIX: Send to Voice Worker (VITAL for server mode)
                        if voice_worker:
                            voice_worker.add_audio(audio_array)

                        # Buffer for simple UI updates (throttled)
                        audio_buffer.append(audio_array)
                        if len(audio_buffer) >= 5:
                            if voice_worker:
                                voice_probs, voice_emotion = voice_worker.get_probs()
                                await websocket.send_json({
                                    "type": "voice_emotion",
                                    "emotion": voice_emotion
                                })
                            audio_buffer.clear()
                    except ValueError as ve:
                        if audio_count % 50 == 0:
                            print(f"[Audio] ⚠️ Buffer error (chunk {audio_count}): {ve}", flush=True)
                    except Exception as e:
                        if audio_count % 50 == 0:
                            print(f"[Audio] ⚠️ Error (chunk {audio_count}): {e}", flush=True)
                # ===== SPEECH END =====
                elif msg_type == "speech_end":
                    transcription = data.get("text", "").strip()
                    print(f"\n{'='*80}", flush=True)
                    print(f"[Speech] 🎤 {username}: '{transcription}'", flush=True)
                    print(f"{'='*80}", flush=True)
                    try:
                        await process_speech_end(
                            websocket, transcription, session_id, user_id,
                            username, user_summary, user_preferences
                        )
                    except Exception as e:
                        print(f"[Speech] ❌ Error: {e}", flush=True)
                        await websocket.send_json({
                            "type": "error",
                            "message": f"Error: {str(e)}"
                        })

                else:
                    print(f"[WebSocket] ⚠️ Unknown type: {msg_type}", flush=True)

            except WebSocketDisconnect:
                print(f"[WebSocket] 🔌 {username} disconnected (WebSocketDisconnect)", flush=True)
                break
            except Exception as err:
                if websocket.client_state != WebSocketState.CONNECTED:
                    print(f"[WebSocket] 🔌 {username} disconnected (error)", flush=True)
                    break
                print(f"[WebSocket] ⚠️ Loop error: {err}", flush=True)
                await asyncio.sleep(0.1)
    except Exception as outer_err:
        print(f"[WebSocket] ❌ Fatal: {outer_err}", flush=True)

    finally:
        print(f"\n[WebSocket] 📊 Session stats for {username or 'Unknown'}:", flush=True)
        print(f" - Messages: {message_count}", flush=True)
        print(f" - Video frames: {video_count}", flush=True)
        print(f" - Audio chunks: {audio_count}", flush=True)

        if session_id and user_id:
            try:
                await generate_session_summary(session_id, user_id)
            except Exception as e:
                print(f"[Summary] ❌ Error: {e}", flush=True)

        print(f"[WebSocket] 🔌 Session closed for {username or 'Unknown'}", flush=True)
        print("="*80 + "\n", flush=True)