Commit 18b1b4d by michon · 1 parent: 7659fc4

chat history try 2

mrrrme/backend_server.py CHANGED
```diff
@@ -47,6 +47,10 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.middleware.cors import CORSMiddleware
 import requests
 from PIL import Image
+import uuid
+
+# Chat history helper
+from mrrrme.utils import chat_history
 
 # Check GPU
 if not torch.cuda.is_available():
@@ -165,6 +169,15 @@ async def health():
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
     print("[WebSocket] ✅ Client connected!")
+    # Create a session id; used if the client doesn't provide an identity.
+    session_id = str(uuid.uuid4())
+    user_key = f"session_{session_id}"
+
+    # Send the session id to the client so it can store it and identify itself later.
+    try:
+        await websocket.send_json({"type": "session", "session_id": session_id})
+    except Exception:
+        pass
 
     # Wait for models to load if needed
     if not models_ready:
@@ -208,6 +221,19 @@ async def websocket_endpoint(websocket: WebSocket):
                 print(f"[Preferences] Updated: voice={user_preferences.get('voice')}, language={user_preferences.get('language')}")
                 continue
 
+            # ============ IDENTIFY / SET USER ============
+            if msg_type == "identify":
+                # The client can send { "type": "identify", "user_id": "some-id" }.
+                incoming_user = data.get("user_id")
+                if incoming_user:
+                    user_key = f"user_{incoming_user}"
+                    print(f"[Session] Identified user: {incoming_user}")
+                    # Load any existing summary and send it to the client.
+                    summary = chat_history.load_summary(user_key)
+                    if summary:
+                        await websocket.send_json({"type": "summary", "summary": summary})
+                continue
+
             # ============ VIDEO FRAME ============
             if msg_type == "video_frame":
                 try:
@@ -291,10 +317,21 @@ async def websocket_endpoint(websocket: WebSocket):
 
                     print(f"[Fusion] Face: {face_emotion}, Voice: {voice_emotion}, Fused: {fused_emotion}")
 
-                    # Generate LLM response
+                    # Load per-user history and pass it as context to the LLM.
+                    history = chat_history.load_history(user_key)
+                    # Optionally include the prior saved summary as a system message.
+                    context_messages = []
+                    saved_summary = chat_history.load_summary(user_key)
+                    if saved_summary:
+                        context_messages.append({"role": "system", "content": f"Previous session summary: {saved_summary}"})
+                    # Include prior messages as context.
+                    context_messages.extend(history)
+
+                    # Generate the LLM response with per-user context.
                     response_text = llm_generator.generate_response(
                         fused_emotion, face_emotion, voice_emotion,
-                        transcription, force=True, intensity=intensity
+                        transcription, force=True, intensity=intensity,
+                        context_messages=context_messages
                     )
 
                     print(f"[LLM] Response: '{response_text}'")
@@ -333,9 +370,28 @@ async def websocket_endpoint(websocket: WebSocket):
                     print(f"[Speech Processing] Error: {e}")
                     import traceback
                     traceback.print_exc()
+                finally:
+                    # Persist the user <-> assistant exchange into per-user history.
+                    try:
+                        if transcription:
+                            chat_history.append_message(user_key, "user", transcription)
+                        if response_text:
+                            chat_history.append_message(user_key, "assistant", response_text)
+                    except Exception as e:
+                        print(f"[History] Failed to persist history: {e}")
 
     except WebSocketDisconnect:
         print("[WebSocket] ❌ Client disconnected")
+        # On disconnect, summarize the conversation and persist the summary for the user.
+        try:
+            history = chat_history.load_history(user_key)
+            if llm_generator and history:
+                summary = llm_generator.summarize_history(history)
+                if summary:
+                    chat_history.save_summary(user_key, summary)
+                    print(f"[History] Saved summary for {user_key}: {summary}")
+        except Exception as e:
+            print(f"[History] Error summarizing on disconnect: {e}")
     except Exception as e:
         print(f"[WebSocket] Error: {e}")
         import traceback
```
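
The commit only touches the server side of the `session`/`identify` handshake. For reference, a minimal client sketch (not part of this commit) that exercises it, assuming the third-party `websockets` package; the endpoint URL and `user_id` below are illustrative placeholders, not taken from the repo:

```python
import asyncio
import json

import websockets  # pip install websockets


async def connect(user_id=None):
    # Placeholder path; use the route actually defined in backend_server.py.
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # The server sends {"type": "session", "session_id": ...} right after accepting.
        hello = json.loads(await ws.recv())
        print("session:", hello.get("session_id"))

        # Claim a stable identity so history and summaries persist across sessions.
        if user_id:
            await ws.send(json.dumps({"type": "identify", "user_id": user_id}))
            # A returning user may get back the saved summary; wait briefly for it.
            try:
                msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=5))
                if msg.get("type") == "summary":
                    print("previous session summary:", msg["summary"])
            except asyncio.TimeoutError:
                pass  # new user: nothing stored yet


asyncio.run(connect(user_id="alice"))
```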
mrrrme/nlp/llm_generator_groq.py CHANGED
```diff
@@ -198,7 +198,8 @@ NEVER: Generic questions, "You seem [emotion]", robotic phrases
 ALWAYS: Match emotion naturally, be genuine"""
 
     def generate_response(self, fused_emotion, face_emotion, voice_emotion,
-                          user_text, force=False, intensity=0.5, is_masking=False):
+                          user_text, force=False, intensity=0.5, is_masking=False,
+                          context_messages=None):
         """Generate response via Groq API"""
         if not force and not user_text:
             return ""
@@ -210,8 +211,14 @@ ALWAYS: Match emotion naturally, be genuine"""
 
         messages = [{"role": "system", "content": system_prompt}]
 
-        for msg in self.conversation_history[-6:]:
-            messages.append(msg)
+        # Use the provided context messages (per-user/session) if available,
+        # otherwise fall back to the generator's internal history.
+        if context_messages is not None:
+            for msg in context_messages[-6:]:
+                messages.append(msg)
+        else:
+            for msg in self.conversation_history[-6:]:
+                messages.append(msg)
 
         messages.append({"role": "user", "content": user_text})
 
@@ -230,17 +237,19 @@ ALWAYS: Match emotion naturally, be genuine"""
         response_text = response.choices[0].message.content.strip()
         response_text = self._clean_response(response_text)
 
-        self.conversation_history.append({
-            "role": "user",
-            "content": user_text
-        })
-        self.conversation_history.append({
-            "role": "assistant",
-            "content": response_text
-        })
-
-        if len(self.conversation_history) > 20:
-            self.conversation_history = self.conversation_history[-20:]
+        # If a context_messages list was provided, do NOT mutate the global
+        # conversation_history here (the caller persists per-user history).
+        if context_messages is None:
+            self.conversation_history.append({
+                "role": "user",
+                "content": user_text
+            })
+            self.conversation_history.append({
+                "role": "assistant",
+                "content": response_text
+            })
+            if len(self.conversation_history) > 20:
+                self.conversation_history = self.conversation_history[-20:]
 
         self.last_response = response_text
 
@@ -296,4 +305,47 @@ ALWAYS: Match emotion naturally, be genuine"""
 
     def clear_history(self):
         self.conversation_history = []
-        print("[LLM] 🗑️ Conversation history cleared")
+        print("[LLM] 🗑️ Conversation history cleared")
+
+    def summarize_history(self, messages=None, max_tokens=120):
+        """Return a concise summary of the provided messages (or the current conversation)."""
+        if messages is None:
+            messages = self.conversation_history
+
+        if not messages:
+            return ""
+
+        # Build the summarization system prompt
+        system_prompt = (
+            "You are an assistant that summarizes short conversations for later context. "
+            "Produce a brief (one to two sentence) summary that captures the user's main concerns, topics, and emotional tone. "
+            "Keep it concise and focused so it can be used as memory the next time the user connects."
+        )
+
+        msg_list = [{"role": "system", "content": system_prompt}]
+
+        # Include at most the last 80 messages to keep the request small
+        for m in messages[-80:]:
+            # Default the role to 'user' if it's missing
+            role = m.get("role", "user")
+            content = m.get("content", "")
+            msg_list.append({"role": role, "content": content})
+
+        try:
+            response = self.client.chat.completions.create(
+                messages=msg_list,
+                model=self.model_name,
+                temperature=0.1,
+                max_tokens=max_tokens,
+                top_p=0.9,
+            )
+
+            summary_text = response.choices[0].message.content.strip()
+            # Keep only the first line of the summary
+            if '\n' in summary_text:
+                summary_text = summary_text.split('\n')[0]
+            return summary_text
+
+        except Exception as e:
+            print(f"[LLM] ❌ Summarization failed: {e}")
+            return ""
```
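
One behavior worth flagging: `generate_response` forwards only `context_messages[-6:]`, so once the loaded history reaches six messages, the `Previous session summary` system message prepended in `backend_server.py` is sliced out of the window. A standalone snippet (plain Python, no API calls, illustrative data) demonstrating the slicing:

```python
# Mirror the context assembly done in backend_server.py, then apply the
# [-6:] window that generate_response uses.
history = [{"role": "user", "content": f"message {i}"} for i in range(10)]

context_messages = [{"role": "system", "content": "Previous session summary: ..."}]
context_messages.extend(history)

window = context_messages[-6:]  # what actually reaches the model
assert len(window) == 6
assert window[0] == {"role": "user", "content": "message 4"}  # the summary is gone
print([m["content"] for m in window])
```

If the summary should always survive, one option is to slice the history before prepending the summary, e.g. `[summary_msg] + history[-5:]`.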
mrrrme/utils/chat_history.py ADDED
```diff
@@ -0,0 +1,75 @@
+"""Simple per-user/session chat history and summary storage."""
+import os
+import json
+from typing import List, Dict
+
+BASE_DIR = os.path.join(os.getcwd(), "chat_histories")
+os.makedirs(BASE_DIR, exist_ok=True)
+
+
+def _filepath(key: str) -> str:
+    safe_key = key.replace(os.path.sep, "_")
+    return os.path.join(BASE_DIR, f"{safe_key}.json")
+
+
+def append_message(key: str, role: str, content: str):
+    """Append a message to the history for `key` (user or session)."""
+    path = _filepath(key)
+    if os.path.exists(path):
+        try:
+            with open(path, "r", encoding="utf-8") as f:
+                data = json.load(f)
+        except Exception:
+            data = {"messages": [], "summary": None}
+    else:
+        data = {"messages": [], "summary": None}
+
+    data["messages"].append({"role": role, "content": content})
+
+    # keep history bounded to the last 200 messages
+    if len(data["messages"]) > 200:
+        data["messages"] = data["messages"][-200:]
+
+    with open(path, "w", encoding="utf-8") as f:
+        json.dump(data, f, ensure_ascii=False, indent=2)
+
+
+def load_history(key: str) -> List[Dict]:
+    path = _filepath(key)
+    if not os.path.exists(path):
+        return []
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        return data.get("messages", [])
+    except Exception:
+        return []
+
+
+def save_summary(key: str, summary: str):
+    path = _filepath(key)
+    if os.path.exists(path):
+        try:
+            with open(path, "r", encoding="utf-8") as f:
+                data = json.load(f)
+        except Exception:
+            data = {"messages": [], "summary": None}
+    else:
+        data = {"messages": [], "summary": None}
+
+    data["summary"] = summary
+
+    with open(path, "w", encoding="utf-8") as f:
+        json.dump(data, f, ensure_ascii=False, indent=2)
+
+
+def load_summary(key: str):
+    path = _filepath(key)
+    if not os.path.exists(path):
+        return None
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        return data.get("summary")
+    except Exception:
+        return None
```
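
A quick round-trip through the new helper, assuming the `mrrrme` package is importable. Files land in `./chat_histories/<key>.json`, with path separators in the key replaced so filenames stay safe; the key below is illustrative:

```python
from mrrrme.utils import chat_history

# Illustrative key; the server uses f"user_{user_id}" or f"session_{uuid}".
key = "user_demo"

chat_history.append_message(key, "user", "Hi, I've been feeling stressed.")
chat_history.append_message(key, "assistant", "I'm here. Want to talk about it?")

print(chat_history.load_history(key))
# -> [{'role': 'user', ...}, {'role': 'assistant', ...}]

chat_history.save_summary(key, "User felt stressed; assistant offered support.")
print(chat_history.load_summary(key))
# -> 'User felt stressed; assistant offered support.'
```

Note that `append_message` does a non-atomic read-modify-write, so two coroutines appending under the same key can drop messages. That's fine while each user has a single connection, but a lock or a write-to-temp-then-rename would be safer otherwise.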