"""MrrrMe Backend - WebSocket handler.

Authenticates a client over the socket, waits for the shared models to be
ready, then dispatches incoming JSON messages (preferences, greeting requests,
video frames, audio chunks, end-of-speech transcriptions) until the client
disconnects. A session summary is generated when an authenticated connection
ends.
"""
from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
import asyncio
import base64
import numpy as np
import cv2
import io
from PIL import Image
import requests

# Import the module for dynamic model access
from . import models as models_module
from .session.manager import validate_token, save_message, load_user_history
from .session.summary import generate_session_summary
from .auth.database import get_db_connection
from .utils.helpers import get_avatar_api_url
from .config import GREETINGS
from .processing.speech import process_speech_end

AVATAR_API = get_avatar_api_url()

# Seconds to wait for the first (auth) message before closing the socket.
AUTH_TIMEOUT_S = 10.0
# Number of 1-second polls while waiting for the models to finish loading.
MODEL_WAIT_POLLS = 900
# Timeout (seconds) for the avatar TTS HTTP request.
TTS_TIMEOUT_S = 10
# Number of audio chunks accumulated between voice-emotion pushes to the client.
VOICE_EMOTION_EVERY_N_CHUNKS = 5


def _load_user_summary(user_id):
    """Return the stored summary text for *user_id*, or None if there is none.

    Any database error is logged and treated as "no summary" so it never blocks
    the connection. The DB connection is closed even when the query raises.
    """
    try:
        conn = get_db_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT summary_text FROM user_summaries WHERE user_id = ?",
                (user_id,),
            )
            summary_row = cursor.fetchone()
        finally:
            conn.close()
    except Exception as e:
        print(f"[WebSocket] ⚠️ Summary error: {e}", flush=True)
        return None
    return summary_row[0] if summary_row else None


def _load_history_into_llm(user_id):
    """Reset the LLM's conversation history and seed it with the user's last 10 messages.

    Does nothing if the model registry has no LLM generator.
    """
    llm_generator = models_module.get_models()['llm_generator']
    if not llm_generator:
        return
    # NOTE(review): if get_models() returns process-wide singletons, this
    # history is shared by every connected user - confirm.
    llm_generator.clear_history()
    user_history = load_user_history(user_id, limit=10)
    for role, content in user_history:
        llm_generator.conversation_history.append({
            "role": role,
            "content": content
        })
    if user_history:
        print(f"[WebSocket] 📚 Loaded {len(user_history)} messages", flush=True)


async def _wait_for_models(websocket):
    """Wait (polling once per second) until the models report ready.

    Sends "loading"/"ready" status messages to the client.

    Returns:
        True once the models are ready. False on timeout; in that case an error
        message has been sent and the socket closed with code 1011.
    """
    if models_module.model_state.ready:
        return True

    print("[WebSocket] ⏳ Waiting for models...", flush=True)
    await websocket.send_json({
        "type": "status",
        "message": "Models loading..."
    })
    for i in range(MODEL_WAIT_POLLS):
        if models_module.model_state.ready:
            await websocket.send_json({
                "type": "status",
                "message": "Models ready!"
            })
            print(f"[WebSocket] ✅ Models ready after {i}s", flush=True)
            return True
        await asyncio.sleep(1)

    if models_module.model_state.ready:
        return True

    print("[WebSocket] ❌ Models timeout", flush=True)
    await websocket.send_json({"type": "error", "message": "Models timeout"})
    await websocket.close(code=1011)
    return False


def _apply_preferences(data, user_preferences, llm_generator, username):
    """Merge the preference keys present in *data* into *user_preferences* (in place).

    If a personality is supplied and an LLM generator exists, the generator's
    personality is updated too.
    """
    for key in ("voice", "language", "personality"):
        if key in data:
            user_preferences[key] = data[key]
    print(f"[Preferences] ✅ {username}: {user_preferences}", flush=True)

    if "personality" in data and llm_generator:
        new_personality = data["personality"]
        llm_generator.set_personality(new_personality)
        print(f"[LLM] 🎭 Personality changed to: {new_personality}", flush=True)


def _request_avatar_tts(text, voice, language):
    """Ask the avatar service to speak *text*.

    Returns:
        ``(audio_url, visemes)`` taken from the service's JSON reply.
        ``(None, None)`` if the request fails or returns a non-200 status.

    This call blocks (it uses ``requests``), so async code must run it in an
    executor. Failures are logged, not raised, so the greeting can fall back to
    text-only.
    """
    try:
        avatar_response = requests.post(
            f"{AVATAR_API}/speak",
            data={
                "text": text,
                "voice": voice,
                "language": language
            },
            timeout=TTS_TIMEOUT_S
        )
        if avatar_response.status_code == 200:
            avatar_data = avatar_response.json()
            audio_url = avatar_data.get("audio_url")
            visemes = avatar_data.get("visemes")
            print("[Greeting] ✅ TTS generated", flush=True)
            return audio_url, visemes
        print(f"[Greeting] ⚠️ TTS failed: {avatar_response.status_code}", flush=True)
    except Exception as tts_err:
        print(f"[Greeting] ⚠️ TTS error: {tts_err}", flush=True)
    return None, None


async def _send_greeting(websocket, username, session_id, user_summary, user_preferences):
    """Build the greeting text, try to voice it via TTS, send it, and persist it.

    Uses the "returning" template when the user has a stored summary and the
    "new" template otherwise. All errors are logged, not raised.
    """
    print(f"[Greeting] 🤖 Request from {username}", flush=True)
    try:
        lang = user_preferences.get("language", "en")
        templates = GREETINGS.get(lang)
        if templates is None:
            # Unsupported language preference: greet in English rather than
            # failing the whole greeting with a KeyError.
            templates = GREETINGS["en"]
        template_key = "returning" if user_summary else "new"
        greeting_text = templates[template_key].format(username=username)

        voice_pref = user_preferences.get("voice", "female")
        loop = asyncio.get_running_loop()
        # requests is blocking: run it off the event loop so other connections
        # are not stalled for up to TTS_TIMEOUT_S seconds.
        audio_url, visemes = await loop.run_in_executor(
            None, _request_avatar_tts, greeting_text, voice_pref, lang
        )

        response_data = {
            "type": "llm_response",
            "text": greeting_text,
            "emotion": "Neutral",
            "intensity": 0.5,
            "is_greeting": True
        }
        if audio_url and visemes:
            response_data["audio_url"] = audio_url
            response_data["visemes"] = visemes
        else:
            response_data["text_only"] = True

        await websocket.send_json(response_data)
        print(f"[Greeting] ✅ Sent to {username}", flush=True)
        save_message(session_id, "assistant", greeting_text, "Neutral")
    except Exception as err:
        print(f"[Greeting] ❌ Error: {err}", flush=True)


async def _send_face_emotion(websocket, face_processor, frame_payload):
    """Decode one base64 image frame, run face analysis, and send the result.

    *frame_payload* may be a data URL (``data:image/...;base64,<data>``) or
    bare base64. Raises on any decode or processing error; the caller decides
    how to log it.
    """
    # Base64 never contains ',', so this drops an optional data-URL prefix.
    encoded = frame_payload.split(",", 1)[-1]
    # convert("RGB") also handles grayscale/palette/alpha images.
    img = Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")
    frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

    # The return value is unused; presumably process_frame updates the
    # processor's "last_*" state read below - confirm.
    face_processor.process_frame(frame)
    face_emotion = face_processor.get_last_emotion() or "Neutral"
    face_confidence = face_processor.get_last_confidence() or 0.0
    face_probs = face_processor.get_last_probs()
    face_quality = getattr(face_processor, 'get_last_quality', lambda: 0.5)()

    await websocket.send_json({
        "type": "face_emotion",
        "emotion": face_emotion,
        "confidence": face_confidence,
        "probabilities": face_probs.tolist(),
        "quality": face_quality
    })


async def _process_audio_chunk(websocket, audio_b64, whisper_worker, voice_worker, audio_buffer):
    """Decode one base64 chunk of 16-bit PCM and feed it to the speech workers.

    Samples are scaled to float32 in [-1, 1). Every
    ``VOICE_EMOTION_EVERY_N_CHUNKS`` chunks, the current voice emotion (if a
    voice worker exists) is pushed to the client. Empty or odd-length chunks,
    which are not valid int16 PCM, are ignored.
    """
    audio_data = base64.b64decode(audio_b64)
    if not audio_data or len(audio_data) % 2:
        return

    audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0

    if whisper_worker:
        whisper_worker.add_audio(audio_array)
    # The voice worker needs the raw stream too (required in server mode).
    if voice_worker:
        voice_worker.add_audio(audio_array)

    # Throttle UI updates: only report voice emotion every few chunks.
    audio_buffer.append(audio_array)
    if len(audio_buffer) >= VOICE_EMOTION_EVERY_N_CHUNKS:
        if voice_worker:
            _voice_probs, voice_emotion = voice_worker.get_probs()
            await websocket.send_json({
                "type": "voice_emotion",
                "emotion": voice_emotion
            })
        # Clear regardless of voice_worker so the buffer cannot grow unbounded.
        audio_buffer.clear()


async def websocket_endpoint(websocket: WebSocket):
    """Main WebSocket endpoint handler.

    Protocol, as implemented here:
      1. Accept the socket. The client must send ``{"type": "auth", "token": ...}``
         within ``AUTH_TIMEOUT_S`` seconds, or the socket is closed with 1008.
      2. Reply ``{"type": "authenticated", ...}`` including the stored summary.
      3. Wait for the models to load (close with 1011 on timeout), then seed the
         LLM conversation history with the user's recent messages.
      4. Dispatch JSON messages by ``type`` until the client disconnects.

    On exit, per-session message counts are logged. For authenticated sessions,
    a session summary is also generated.
    """
    print("\n" + "="*80, flush=True)
    print("[WebSocket] 🔌 NEW CONNECTION", flush=True)
    print("="*80, flush=True)

    try:
        await websocket.accept()
        print("[WebSocket] ✅ Connection accepted", flush=True)
    except Exception as e:
        print(f"[WebSocket] ❌ Failed to accept: {e}", flush=True)
        return

    # Session variables (read by the finally-block, so initialise up front)
    user_summary = None
    session_id = None
    user_id = None
    username = None
    audio_buffer = []
    user_preferences = {"voice": "female", "language": "en", "personality": "therapist"}

    # Message counters (for stats, not logging)
    message_count = 0
    video_count = 0
    audio_count = 0

    try:
        # ===== AUTHENTICATION =====
        print("[WebSocket] ⏳ Waiting for auth message...", flush=True)
        try:
            auth_msg = await asyncio.wait_for(websocket.receive_json(), timeout=AUTH_TIMEOUT_S)
        except asyncio.TimeoutError:
            print("[WebSocket] ❌ Auth timeout", flush=True)
            await websocket.close(code=1008, reason="Auth timeout")
            return

        auth_type = auth_msg.get("type") if isinstance(auth_msg, dict) else None
        if auth_type != "auth":
            print(f"[WebSocket] ❌ Wrong type: {auth_type}", flush=True)
            await websocket.send_json({"type": "error", "message": "Auth required"})
            await websocket.close(code=1008)
            return

        token = auth_msg.get("token")
        if not token:
            print("[WebSocket] ❌ No token", flush=True)
            await websocket.send_json({"type": "error", "message": "No token"})
            await websocket.close(code=1008)
            return

        session_data = validate_token(token)
        if not session_data:
            print("[WebSocket] ❌ Invalid token", flush=True)
            await websocket.send_json({"type": "error", "message": "Invalid session"})
            await websocket.close(code=1008)
            return

        session_id = session_data['session_id']
        user_id = session_data['user_id']
        username = session_data['username']
        print(f"[WebSocket] ✅ Authenticated: {username}", flush=True)

        user_summary = _load_user_summary(user_id)

        await websocket.send_json({
            "type": "authenticated",
            "username": username,
            "summary": user_summary
        })

        # Models must be ready BEFORE seeding history; otherwise a user who
        # connects during model load would never get their history loaded.
        if not await _wait_for_models(websocket):
            return
        _load_history_into_llm(user_id)

        print(f"[WebSocket] ✅ Ready - starting message loop for {username}", flush=True)

        # ===== MESSAGE LOOP =====
        while True:
            if websocket.client_state != WebSocketState.CONNECTED:
                print(f"[WebSocket] 🔌 {username} disconnected (client closed)", flush=True)
                break

            try:
                data = await websocket.receive_json()
                message_count += 1
                msg_type = data.get("type")

                # Re-read on every message so model swaps/reloads are picked up.
                models = models_module.get_models()

                if msg_type == "preferences":
                    _apply_preferences(data, user_preferences, models['llm_generator'], username)

                elif msg_type == "request_greeting":
                    await _send_greeting(websocket, username, session_id, user_summary, user_preferences)

                elif msg_type == "video_frame":
                    video_count += 1
                    try:
                        await _send_face_emotion(websocket, models['face_processor'], data["frame"])
                    except Exception as e:
                        # Log only occasionally: a bad stream fails on every frame.
                        if video_count % 100 == 0:
                            print(f"[Video] ⚠️ Error (frame {video_count}): {e}", flush=True)

                elif msg_type == "audio_chunk":
                    audio_count += 1
                    try:
                        await _process_audio_chunk(
                            websocket,
                            data["audio"],
                            models['whisper_worker'],
                            models['voice_worker'],
                            audio_buffer,
                        )
                    except ValueError as ve:
                        if audio_count % 50 == 0:
                            print(f"[Audio] ⚠️ Buffer error (chunk {audio_count}): {ve}", flush=True)
                    except Exception as e:
                        if audio_count % 50 == 0:
                            print(f"[Audio] ⚠️ Error (chunk {audio_count}): {e}", flush=True)

                elif msg_type == "speech_end":
                    transcription = (data.get("text") or "").strip()
                    print(f"\n{'='*80}", flush=True)
                    print(f"[Speech] 🎤 {username}: '{transcription}'", flush=True)
                    print(f"{'='*80}", flush=True)
                    try:
                        await process_speech_end(
                            websocket,
                            transcription,
                            session_id,
                            user_id,
                            username,
                            user_summary,
                            user_preferences
                        )
                    except Exception as e:
                        print(f"[Speech] ❌ Error: {e}", flush=True)
                        await websocket.send_json({
                            "type": "error",
                            "message": f"Error: {str(e)}"
                        })

                else:
                    print(f"[WebSocket] ⚠️ Unknown type: {msg_type}", flush=True)

            except WebSocketDisconnect:
                print(f"[WebSocket] 🔌 {username} disconnected (WebSocketDisconnect)", flush=True)
                break
            except Exception as err:
                if websocket.client_state != WebSocketState.CONNECTED:
                    print(f"[WebSocket] 🔌 {username} disconnected (error)", flush=True)
                    break
                print(f"[WebSocket] ⚠️ Loop error: {err}", flush=True)
                # Brief back-off so a persistently failing socket cannot spin the CPU.
                await asyncio.sleep(0.1)

    except WebSocketDisconnect:
        # Only reachable from the setup phase; the message loop handles its own.
        print(f"[WebSocket] 🔌 {username or 'Unknown'} disconnected during setup", flush=True)
    except Exception as outer_err:
        print(f"[WebSocket] ❌ Fatal: {outer_err}", flush=True)

    finally:
        print(f"\n[WebSocket] 📊 Session stats for {username or 'Unknown'}:", flush=True)
        print(f"   - Messages: {message_count}", flush=True)
        print(f"   - Video frames: {video_count}", flush=True)
        print(f"   - Audio chunks: {audio_count}", flush=True)

        if session_id and user_id:
            try:
                await generate_session_summary(session_id, user_id)
            except Exception as e:
                print(f"[Summary] ❌ Error: {e}", flush=True)

        print(f"[WebSocket] 👋 Session closed for {username or 'Unknown'}", flush=True)
        print("="*80 + "\n", flush=True)