"""
MrrrMe Backend - WebSocket Handler (PRODUCTION - FIXED DISCONNECTION HANDLING)
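
Protocol (as implemented below):
  1. The client must send {"type": "auth", "token": "<session token>"} within
     10 seconds; the server replies with an "authenticated" message, or sends
     an "error" and closes with code 1008.
  2. Once the models are ready, the client may send "preferences",
     "request_greeting", "video_frame", "audio_chunk", and "speech_end"
     messages; the server streams back "face_emotion", "voice_emotion",
     "llm_response", "status", and "error" messages.
  3. On disconnect, the finally block logs session stats and generates a
     session summary.

Minimal client sketch (the host/port and route path are deployment-specific
assumptions, not defined in this module):

    import asyncio, json, websockets

    async def demo(token: str):
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            await ws.send(json.dumps({"type": "auth", "token": token}))
            print(json.loads(await ws.recv()))  # {"type": "authenticated", ...}
            await ws.send(json.dumps({"type": "request_greeting"}))
            print(json.loads(await ws.recv()))  # {"type": "llm_response", ...} (or a "status" update while models load)

    asyncio.run(demo("<session token>"))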
"""
from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
import asyncio
import base64
import numpy as np
import cv2
import io
from PIL import Image
import requests
# Import the module for dynamic model access
from . import models as models_module
from .session.manager import validate_token, save_message, load_user_history
from .session.summary import generate_session_summary
from .auth.database import get_db_connection
from .utils.helpers import get_avatar_api_url
from .config import GREETINGS
from .processing.speech import process_speech_end
AVATAR_API = get_avatar_api_url()
async def websocket_endpoint(websocket: WebSocket):
    """Main WebSocket endpoint handler."""
    print("\n" + "="*80, flush=True)
    print("[WebSocket] 🔌 NEW CONNECTION", flush=True)
    print("="*80, flush=True)
    try:
        await websocket.accept()
        print("[WebSocket] ✅ Connection accepted", flush=True)
    except Exception as e:
        print(f"[WebSocket] ❌ Failed to accept: {e}", flush=True)
        return
    # Session variables
    session_data = None
    user_summary = None
    session_id = None
    user_id = None
    username = None
    audio_buffer = []
    user_preferences = {"voice": "female", "language": "en", "personality": "therapist"}

    # Message counters (for stats, not logging)
    message_count = 0
    video_count = 0
    audio_count = 0
    try:
        # ===== AUTHENTICATION =====
        print("[WebSocket] ⏳ Waiting for auth message...", flush=True)
        try:
            auth_msg = await asyncio.wait_for(websocket.receive_json(), timeout=10.0)
        except asyncio.TimeoutError:
            print("[WebSocket] ❌ Auth timeout", flush=True)
            await websocket.close(code=1008, reason="Auth timeout")
            return

        if auth_msg.get("type") != "auth":
            print(f"[WebSocket] ❌ Wrong type: {auth_msg.get('type')}", flush=True)
            await websocket.send_json({"type": "error", "message": "Auth required"})
            await websocket.close(code=1008)
            return

        token = auth_msg.get("token")
        if not token:
            print("[WebSocket] ❌ No token", flush=True)
            await websocket.send_json({"type": "error", "message": "No token"})
            await websocket.close(code=1008)
            return

        # Validate
        session_data = validate_token(token)
        if not session_data:
            print("[WebSocket] ❌ Invalid token", flush=True)
            await websocket.send_json({"type": "error", "message": "Invalid session"})
            await websocket.close(code=1008)
            return

        session_id = session_data['session_id']
        user_id = session_data['user_id']
        username = session_data['username']
        print(f"[WebSocket] ✅ Authenticated: {username}", flush=True)
        # Get summary
        try:
            conn = get_db_connection()
            cursor = conn.cursor()
            cursor.execute("SELECT summary_text FROM user_summaries WHERE user_id = ?", (user_id,))
            summary_row = cursor.fetchone()
            user_summary = summary_row[0] if summary_row else None
            conn.close()
        except Exception as e:
            print(f"[WebSocket] ⚠️ Summary error: {e}", flush=True)
            user_summary = None

        # Send auth confirmation
        await websocket.send_json({
            "type": "authenticated",
            "username": username,
            "summary": user_summary
        })
        # Load history
        current_models = models_module.get_models()
        if current_models['llm_generator']:
            current_models['llm_generator'].clear_history()
            user_history = load_user_history(user_id, limit=10)
            for role, content in user_history:
                current_models['llm_generator'].conversation_history.append({
                    "role": role,
                    "content": content
                })
            if user_history:
                print(f"[WebSocket] 📚 Loaded {len(user_history)} messages", flush=True)
        # Check models ready
        if not models_module.model_state.ready:
            print("[WebSocket] ⏳ Waiting for models...", flush=True)
            await websocket.send_json({
                "type": "status",
                "message": "Models loading..."
            })
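            # Poll once per second for up to 900 s (15 minutes); if the models never
            # come up, report the timeout and close the connection below.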
            for i in range(900):
                if models_module.model_state.ready:
                    await websocket.send_json({
                        "type": "status",
                        "message": "Models ready!"
                    })
                    print(f"[WebSocket] ✅ Models ready after {i}s", flush=True)
                    break
                await asyncio.sleep(1)

            if not models_module.model_state.ready:
                print("[WebSocket] ❌ Models timeout", flush=True)
                await websocket.send_json({"type": "error", "message": "Models timeout"})
                await websocket.close(code=1011)
                return
        print(f"[WebSocket] ✅ Ready - starting message loop for {username}", flush=True)

        # ===== MESSAGE LOOP =====
        while True:
            # Check if WebSocket is still connected
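            # (calling receive_json() on a socket that is already closed raises, so bail
            # out of the loop as soon as the client state is no longer CONNECTED)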
            if websocket.client_state != WebSocketState.CONNECTED:
                print(f"[WebSocket] 🔌 {username} disconnected (client closed)", flush=True)
                break

            try:
                data = await websocket.receive_json()
                message_count += 1
                msg_type = data.get("type")

                # Get latest models dynamically
                models = models_module.get_models()
                face_processor = models['face_processor']
                voice_worker = models['voice_worker']
                whisper_worker = models['whisper_worker']
                # ===== PREFERENCES =====
                if msg_type == "preferences":
                    if "voice" in data:
                        user_preferences["voice"] = data.get("voice", "female")
                    if "language" in data:
                        user_preferences["language"] = data.get("language", "en")
                    if "personality" in data:
                        user_preferences["personality"] = data.get("personality", "therapist")
                    print(f"[Preferences] ✅ {username}: {user_preferences}", flush=True)

                    if "personality" in data and models['llm_generator']:
                        new_personality = data.get("personality", "therapist")
                        models['llm_generator'].set_personality(new_personality)
                        print(f"[LLM] 🎭 Personality changed to: {new_personality}", flush=True)
                    continue
                # ===== GREETING =====
                elif msg_type == "request_greeting":
                    print(f"[Greeting] 🤖 Request from {username}", flush=True)
                    try:
                        lang = user_preferences.get("language", "en")
                        if user_summary:
                            greeting_text = GREETINGS[lang]["returning"].format(username=username)
                        else:
                            greeting_text = GREETINGS[lang]["new"].format(username=username)
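                        # Ask the avatar service to synthesize the greeting; on success it
                        # returns an audio URL plus viseme timings for lip-sync, otherwise we
                        # fall back to a text-only response below.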
                        # Try avatar TTS
                        audio_url = None
                        visemes = None
                        try:
                            voice_pref = user_preferences.get("voice", "female")
                            lang_pref = user_preferences.get("language", "en")
                            avatar_response = requests.post(
                                f"{AVATAR_API}/speak",
                                data={
                                    "text": greeting_text,
                                    "voice": voice_pref,
                                    "language": lang_pref
                                },
                                timeout=10
                            )
                            if avatar_response.status_code == 200:
                                avatar_data = avatar_response.json()
                                audio_url = avatar_data.get("audio_url")
                                visemes = avatar_data.get("visemes")
                                print("[Greeting] ✅ TTS generated", flush=True)
                            else:
                                print(f"[Greeting] ⚠️ TTS failed: {avatar_response.status_code}", flush=True)
                        except Exception as tts_err:
                            print(f"[Greeting] ⚠️ TTS error: {tts_err}", flush=True)

                        response_data = {
                            "type": "llm_response",
                            "text": greeting_text,
                            "emotion": "Neutral",
                            "intensity": 0.5,
                            "is_greeting": True
                        }
                        if audio_url and visemes:
                            response_data["audio_url"] = audio_url
                            response_data["visemes"] = visemes
                        else:
                            response_data["text_only"] = True

                        await websocket.send_json(response_data)
                        print(f"[Greeting] ✅ Sent to {username}", flush=True)
                        save_message(session_id, "assistant", greeting_text, "Neutral")
                    except Exception as err:
                        print(f"[Greeting] ❌ Error: {err}", flush=True)
                # ===== VIDEO FRAME =====
                elif msg_type == "video_frame":
                    video_count += 1
                    try:
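                        # Frames arrive as data URLs ("data:image/...;base64,<payload>"):
                        # strip the header, decode the payload, and convert RGB -> BGR for OpenCV.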
                        img_data = base64.b64decode(data["frame"].split(",")[1])
                        img = Image.open(io.BytesIO(img_data))
                        frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

                        processed_frame, result = face_processor.process_frame(frame)
                        face_emotion = face_processor.get_last_emotion() or "Neutral"
                        face_confidence = face_processor.get_last_confidence() or 0.0
                        face_probs = face_processor.get_last_probs()
                        face_quality = getattr(face_processor, 'get_last_quality', lambda: 0.5)()

                        await websocket.send_json({
                            "type": "face_emotion",
                            "emotion": face_emotion,
                            "confidence": face_confidence,
                            "probabilities": face_probs.tolist() if face_probs is not None else None,
                            "quality": face_quality
                        })
                    except Exception as e:
                        if video_count % 100 == 0:
                            print(f"[Video] ⚠️ Error (frame {video_count}): {e}", flush=True)
                # ===== AUDIO CHUNK (UPDATED FIX) =====
                elif msg_type == "audio_chunk":
                    audio_count += 1
                    try:
                        audio_data = base64.b64decode(data["audio"])

                        # Validate buffer size (int16 samples are 2 bytes each)
                        if len(audio_data) % 2 != 0 or len(audio_data) == 0:
                            continue

                        # Convert 16-bit PCM to a float32 array in [-1.0, 1.0)
                        audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0

                        # 1. Send to Whisper
                        if whisper_worker:
                            whisper_worker.add_audio(audio_array)

                        # 2. ✅ FIX: Send to Voice Worker (VITAL for server mode)
                        if voice_worker:
                            voice_worker.add_audio(audio_array)
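                        # Every 5 chunks, poll the voice worker for its current emotion
                        # estimate and push a single voice_emotion update to the client.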
                        # Buffer for simple UI updates (throttled)
                        audio_buffer.append(audio_array)
                        if len(audio_buffer) >= 5:
                            if voice_worker:
                                voice_probs, voice_emotion = voice_worker.get_probs()
                                await websocket.send_json({
                                    "type": "voice_emotion",
                                    "emotion": voice_emotion
                                })
                            audio_buffer.clear()
                    except ValueError as ve:
                        if audio_count % 50 == 0:
                            print(f"[Audio] ⚠️ Buffer error (chunk {audio_count}): {ve}", flush=True)
                    except Exception as e:
                        if audio_count % 50 == 0:
                            print(f"[Audio] ⚠️ Error (chunk {audio_count}): {e}", flush=True)
                # ===== SPEECH END =====
                elif msg_type == "speech_end":
                    transcription = data.get("text", "").strip()
                    print(f"\n{'='*80}", flush=True)
                    print(f"[Speech] 🎤 {username}: '{transcription}'", flush=True)
                    print(f"{'='*80}", flush=True)
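                    # Hand the final transcript off to the speech-processing pipeline,
                    # which replies to the client over this same websocket.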
                    try:
                        await process_speech_end(
                            websocket, transcription, session_id, user_id,
                            username, user_summary, user_preferences
                        )
                    except Exception as e:
                        print(f"[Speech] ❌ Error: {e}", flush=True)
                        await websocket.send_json({
                            "type": "error",
                            "message": f"Error: {str(e)}"
                        })
                else:
                    print(f"[WebSocket] ⚠️ Unknown type: {msg_type}", flush=True)

            except WebSocketDisconnect:
                print(f"[WebSocket] 🔌 {username} disconnected (WebSocketDisconnect)", flush=True)
                break
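            # Any other error: if the socket is actually gone, treat it as a disconnect;
            # otherwise log it and keep the session alive after a short pause.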
            except Exception as err:
                if websocket.client_state != WebSocketState.CONNECTED:
                    print(f"[WebSocket] 🔌 {username} disconnected (error)", flush=True)
                    break
                print(f"[WebSocket] ⚠️ Loop error: {err}", flush=True)
                await asyncio.sleep(0.1)
    except Exception as outer_err:
        print(f"[WebSocket] ❌ Fatal: {outer_err}", flush=True)

    finally:
        print(f"\n[WebSocket] 📊 Session stats for {username or 'Unknown'}:", flush=True)
        print(f" - Messages: {message_count}", flush=True)
        print(f" - Video frames: {video_count}", flush=True)
        print(f" - Audio chunks: {audio_count}", flush=True)
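        # Summarize this session before closing; the user_summaries lookup during
        # authentication suggests this feeds the "returning user" greeting next time.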
        if session_id and user_id:
            try:
                await generate_session_summary(session_id, user_id)
            except Exception as e:
                print(f"[Summary] ❌ Error: {e}", flush=True)

        print(f"[WebSocket] 👋 Session closed for {username or 'Unknown'}", flush=True)
        print("="*80 + "\n", flush=True)