"""Speech-to-text transcription using Distil-Whisper with Voice Activity Detection (OPTIMIZED FOR NATURAL PAUSES)"""
import time
import threading
import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline as hf_pipeline

from ..config import WHISPER_MODEL, TRANSCRIPTION_BUFFER_SEC

# --- Tunables for turn-taking (OPTIMIZED FOR NATURAL CONVERSATION) ---
HOLD_MS = 1200          # ⭐ LONGER: Wait for natural pauses (was 400)
SHORT_PAUSE_MS = 500    # ⭐ NEW: Brief pause (thinking sounds like "hmm")
MIN_UTTER_MS = 300      # Minimum utterance length
MIN_CHARS = 2           # Minimum characters
ASR_SR = 16000          # Expected sample rate for ASR/VAD
RECENT_SEC_FOR_VAD = 0.5  # How much recent audio to check for speech prob
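
# Timeline sketch of how the thresholds above interact (illustrative):
#   silence <  SHORT_PAUSE_MS (500 ms)    -> same utterance, keep buffering
#   SHORT_PAUSE_MS <= silence < HOLD_MS   -> probe the tail for a thinking sound
#   silence >= HOLD_MS (1200 ms)          -> finalize and transcribe the utterance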

# ⭐ THINKING SOUNDS - These indicate user is STILL talking, just pausing to think
THINKING_SOUNDS = {
    "um", "uh", "hmm", "mhm", "uh-huh", "mm-hmm",
    "err", "ah", "eh", "umm", "uhh", "hmmm"
}

# ⭐ Short affirmations ("yeah", "yes", "okay", "ok") were removed from this
# set on purpose - they are valid responses, not hallucinations to filter out.
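#
# Illustrative examples of how the set is applied (single-word matches only):
#   "hmm"        -> treated as a thinking sound; keep waiting for more speech
#   "hmm, maybe" -> two words, so it is transcribed as a real utterance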


class WhisperTranscriptionWorker:
    """

    Distil-Whisper transcription with Silero VAD-based turn-taking.

    NOW WITH INTELLIGENT PAUSE DETECTION!

    """

    def __init__(self, text_analyzer, model_size=WHISPER_MODEL):
        print(f"\n[Whisper] πŸš€ Initializing...")
        print(f"[Whisper] πŸ“¦ Loading DISTILLED model: {model_size}")

        # Detect device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        print(f"[Whisper] πŸ–₯️  Device: {device} (dtype: {torch_dtype})")

        # Load Whisper model with error handling
        try:
            print(f"[Whisper] πŸ“₯ Downloading/loading Whisper model...")
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_size,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True
            ).to(device)
            print(f"[Whisper] βœ… Whisper model loaded")
        except Exception as e:
            print(f"[Whisper] ❌ Failed to load Whisper model: {e}")
            raise

        try:
            print(f"[Whisper] πŸ“₯ Loading processor...")
            processor = AutoProcessor.from_pretrained(model_size)
            print(f"[Whisper] βœ… Processor loaded")
        except Exception as e:
            print(f"[Whisper] ❌ Failed to load processor: {e}")
            raise

        # Create pipeline
        try:
            print(f"[Whisper] πŸ”§ Building ASR pipeline...")
            self.model = hf_pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                max_new_tokens=80,
                chunk_length_s=15,
                batch_size=32,
                torch_dtype=torch_dtype,
                device=device,
            )
            print(f"[Whisper] βœ… ASR pipeline ready")
        except Exception as e:
            print(f"[Whisper] ❌ Failed to build pipeline: {e}")
            raise

        # Silero VAD
        print("[Whisper] πŸ”§ Loading Silero VAD for speech detection...")
        try:
            self.vad_model, utils = torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad',
                force_reload=False,
                onnx=False
            )
            self.get_speech_timestamps = utils[0]
            print("[Whisper] βœ… Silero VAD loaded")
        except Exception as e:
            print(f"[Whisper] ❌ Failed to load VAD: {e}")
            raise

        # State
        self.text_analyzer = text_analyzer
        self.audio_buffer = []
        self.speech_buffer = []
        self.lock = threading.Lock()
        self.running = False

        # Turn-taking timers
        self.is_speaking = False
        self.last_speech_ts = 0.0
        self.utter_start_ts = None
        
        # ⭐ NEW: Thinking detection
        self.consecutive_thinking_sounds = 0
        self.last_thinking_detection = 0.0

        # VAD thresholds
        self.silence_threshold = 0.4
        self.speech_threshold = 0.4

        # Controls
        self.response_callback = None

        # Pause gating
        self.paused = False
        self.pause_lock = threading.Lock()

        # Cap raw-audio buffering (rough estimate: one chunk per VAD window)
        self.max_chunks = max(1, int(TRANSCRIPTION_BUFFER_SEC / max(RECENT_SEC_FOR_VAD, 0.1)))

        # Stats
        self.transcription_count = 0
        self.total_audio_seconds = 0.0

        print(f"[Whisper] βš™οΈ  Config (NATURAL PAUSE MODE):")
        print(f"[Whisper]   - HOLD_MS: {HOLD_MS}ms (patient waiting)")
        print(f"[Whisper]   - SHORT_PAUSE_MS: {SHORT_PAUSE_MS}ms (thinking detection)")
        print(f"[Whisper]   - MIN_UTTER_MS: {MIN_UTTER_MS}ms")
        print(f"[Whisper]   - Thinking sounds: {THINKING_SOUNDS}")
        print("[Whisper] βœ… Ready! Will wait patiently for you to finish thinking.\n")

    # -------- Public API --------

    def set_response_callback(self, callback):
        self.response_callback = callback
        print(f"[Whisper] βœ… Response callback registered")

    def pause_listening(self):
        """Called by TTS or coordinator: stop reacting while the AI speaks."""
        with self.pause_lock:
            was_paused = self.paused
            self.paused = True
        if not was_paused:
            print("[Whisper] ⏸️  PAUSED (TTS speaking)")

    def resume_listening(self):
        """Called when TTS ends: clear buffers, then listen again."""
        with self.lock:
            audio_cleared = len(self.audio_buffer)
            speech_cleared = len(self.speech_buffer)
            self.audio_buffer = []
            self.speech_buffer = []
        
        with self.pause_lock:
            self.paused = False
        
        # Reset speaking state
        self.is_speaking = False
        self.utter_start_ts = None
        self.last_speech_ts = 0.0
        self.consecutive_thinking_sounds = 0
        
        total_cleared = audio_cleared + speech_cleared
        print(f"[Whisper] ▢️  RESUMED (cleared {total_cleared} chunks)")

    def add_audio(self, audio_chunk: np.ndarray):
        """Ingest mono float32 audio at 16 kHz."""
        with self.pause_lock:
            if self.paused:
                return
        
        if audio_chunk is None or len(audio_chunk) == 0:
            return
        
        with self.lock:
            self.audio_buffer.append(audio_chunk.astype(np.float32, copy=False))
            if len(self.audio_buffer) > self.max_chunks:
                trimmed = len(self.audio_buffer) - self.max_chunks
                self.audio_buffer = self.audio_buffer[-self.max_chunks:]
                if trimmed > 10:
                    print(f"[Whisper] πŸ—‘οΈ  Trimmed {trimmed} old chunks")

    def start(self):
        if self.running:
            print("[Whisper] ⚠️ Already running!")
            return
        
        self.running = True
        self.th = threading.Thread(target=self._transcription_loop, daemon=True)
        self.th.start()
        print("[Whisper] ▢️  Transcription loop started")

    def stop(self):
        if not self.running:
            print("[Whisper] ⚠️ Already stopped!")
            return

        self.running = False
        print("[Whisper] ⏹️  Stopping...")
        # Give the worker thread one poll cycle to exit before reporting stats
        if hasattr(self, "th"):
            self.th.join(timeout=1.0)
        print(f"[Whisper] 📊 Stats: {self.transcription_count} transcriptions, {self.total_audio_seconds:.1f}s total audio")

    def get_state(self):
        """Debug: get current state"""
        with self.lock:
            audio_len = len(self.audio_buffer)
            speech_len = len(self.speech_buffer)
        with self.pause_lock:
            paused = self.paused
        
        return {
            'paused': paused,
            'is_speaking': self.is_speaking,
            'audio_buffer_len': audio_len,
            'speech_buffer_len': speech_len,
            'transcription_count': self.transcription_count
        }

    # -------- Internals --------

    def _detect_speech_prob(self, audio_recent: np.ndarray) -> float:
        """Silero expects exactly 512 samples @16k for prob()."""
        try:
            required = 512
            if audio_recent.shape[0] < required:
                return 0.0
            audio_recent = audio_recent[-required:]
            audio_tensor = torch.from_numpy(audio_recent).float()
            prob = float(self.vad_model(audio_tensor, ASR_SR).item())
            return prob
        except Exception as e:
            print(f"[Whisper] ⚠️ VAD error: {e}")
            return 0.0

    def _check_for_thinking_sound(self, audio_snippet: np.ndarray) -> bool:
        """

        ⭐ NEW: Quick transcription check to detect thinking sounds.

        Returns True if this is likely "hmm", "umm", etc.

        """
        try:
            duration = len(audio_snippet) / ASR_SR
            if duration < 0.2 or duration > 1.5:  # Thinking sounds are brief
                return False
            
            # Quick transcribe
            result = self.model({"array": audio_snippet, "sampling_rate": ASR_SR})
            text = (result.get("text") or "").strip().lower()
            
            # Check if it's a thinking sound
            words = text.split()
            if len(words) == 1 and words[0] in THINKING_SOUNDS:
                print(f"[Whisper] πŸ€” Detected thinking sound: '{text}' - WAITING for more...")
                return True
            
            return False
        except Exception as e:
            print(f"[Whisper] ⚠️ Thinking detection error: {e}")
            return False

    def _finalize_and_transcribe(self):
        # Collect utterance audio atomically
        with self.lock:
            if not self.speech_buffer:
                return
            audio = np.concatenate(self.speech_buffer, axis=0)
            self.speech_buffer = []

        # Quality gates
        duration = len(audio) / ASR_SR
        if duration < MIN_UTTER_MS / 1000.0:
            print(f"[Whisper] ⏭️  Skipping (too short: {duration:.2f}s)")
            return

        energy = np.abs(audio).mean()
        if energy < 0.003:
            print(f"[Whisper] ⏭️  Skipping (too quiet: energy={energy:.4f})")
            return

        print(f"[Whisper] πŸŽ™οΈ  Transcribing {duration:.2f}s of speech...")
        start_time = time.time()
        
        try:
            result = self.model({"array": audio, "sampling_rate": ASR_SR})
            text = (result.get("text") or "").strip()
            
            transcribe_time = time.time() - start_time
            print(f"[Whisper] ⏱️  Transcription took {transcribe_time:.2f}s")
            
        except Exception as e:
            print(f"[Whisper] ❌ Transcription error: {e}")
            import traceback
            traceback.print_exc()
            return

        if not text or len(text) < MIN_CHARS:
            print(f"[Whisper] ⏭️  Skipping (short text: '{text}')")
            return

        # Filter ONLY isolated thinking sounds with low energy
        t_low = text.lower().strip()
        word_count = len(t_low.split())
        
        if word_count == 1 and t_low in THINKING_SOUNDS and energy < 0.004:
            print(f"[Whisper] 🚫 Filtered isolated thinking sound: '{text}'")
            return

        # Valid transcription!
        self.transcription_count += 1
        self.total_audio_seconds += duration
        print(f"[Whisper] βœ… Transcribed #{self.transcription_count}: '{text}'")
        
        # Send to text analyzer
        try:
            if self.text_analyzer:
                self.text_analyzer.analyze(text)
        except Exception as e:
            print(f"[Whisper] ⚠️ Text analyzer error: {e}")

        # Send to callback
        if self.response_callback:
            with self.pause_lock:
                if self.paused:
                    print(f"[Whisper] ⚠️ Skipping callback (paused mid-transcription)")
                    return
            
            try:
                self.response_callback(text)
            except Exception as e:
                print(f"[Whisper] ❌ Callback error: {e}")
                import traceback
                traceback.print_exc()

    def _transcription_loop(self):
        """

        ⭐ ENHANCED: Real-time VAD with intelligent pause detection.

        Waits patiently during thinking sounds and mid-sentence pauses.

        """
        poll = 0.05  # 50ms loop
        loop_count = 0
        
        print("[Whisper] πŸ”„ Transcription loop running (PATIENT MODE)...")
        
        while self.running:
            loop_count += 1
            time.sleep(poll)

            if loop_count % 200 == 0:
                state = self.get_state()
                print(f"[Whisper] πŸ’“ Heartbeat: speaking={state['is_speaking']}, "
                      f"transcriptions={state['transcription_count']}")

            with self.pause_lock:
                if self.paused:
                    continue

            # Snapshot recent audio
            with self.lock:
                if not self.audio_buffer:
                    continue
                hop_est = max(1, int(RECENT_SEC_FOR_VAD / max(poll, 0.01)))
                recent_chunks = self.audio_buffer[-hop_est:]
                
                try:
                    recent_audio = np.concatenate(recent_chunks, axis=0)
                except Exception as e:
                    print(f"[Whisper] ⚠️ Concatenate error: {e}")
                    continue

            # VAD speech prob
            speech_prob = self._detect_speech_prob(recent_audio)
            now = time.time()

            if speech_prob > self.speech_threshold:
                # Speaking detected
                if not self.is_speaking:
                    self.is_speaking = True
                    self.utter_start_ts = now
                    print(f"[Whisper] 🎀 Speech detected (prob: {speech_prob:.2f})")

                self.last_speech_ts = now
                self.consecutive_thinking_sounds = 0  # Reset thinking counter

                # Move audio to speech buffer
                with self.lock:
                    if self.audio_buffer:
                        self.speech_buffer.extend(self.audio_buffer)
                        self.audio_buffer = []

            elif self.is_speaking:
                # Silence while we were speaking
                silence_ms = (now - self.last_speech_ts) * 1000.0
                utter_ms = (self.last_speech_ts - (self.utter_start_ts or now)) * 1000.0

                # Drain remainder
                with self.lock:
                    if self.audio_buffer:
                        self.speech_buffer.extend(self.audio_buffer)
                        self.audio_buffer = []

                # ⭐ SMART PAUSE DETECTION
                if SHORT_PAUSE_MS <= silence_ms < HOLD_MS:
                    # Short pause - check whether the tail is a thinking sound
                    if (now - self.last_thinking_detection) > 1.0:  # Don't check too often
                        with self.lock:
                            recent_speech = (
                                np.concatenate(self.speech_buffer[-10:], axis=0)
                                if self.speech_buffer else None
                            )
                        # Run the (slow) ASR probe outside the lock so add_audio() is not blocked
                        if recent_speech is not None and self._check_for_thinking_sound(
                                recent_speech[-int(ASR_SR * 1.0):]):
                            # Thinking sound! Rewind the silence clock by half a short
                            # pause (SHORT_PAUSE_MS / 2, in seconds) and keep waiting
                            self.last_speech_ts = now - (SHORT_PAUSE_MS / 2000.0)
                            self.consecutive_thinking_sounds += 1
                            self.last_thinking_detection = now
                            print(f"[Whisper] ⏳ Thinking pause detected ({self.consecutive_thinking_sounds}x) - extending wait time")
                            continue

                # Final decision
                if silence_ms >= HOLD_MS and utter_ms >= MIN_UTTER_MS:
                    # Long enough silence - finalize
                    print(f"[Whisper] πŸ”‡ Silence {silence_ms:.0f}ms β†’ finalizing (utter {utter_ms:.0f}ms)")
                    self.is_speaking = False
                    self.utter_start_ts = None
                    self.consecutive_thinking_sounds = 0
                    self._finalize_and_transcribe()
                elif silence_ms >= HOLD_MS:
                    # Too short utterance
                    print(f"[Whisper] ⏭️  Ignoring short utterance ({utter_ms:.0f}ms)")
                    self.is_speaking = False
                    self.utter_start_ts = None
                    self.consecutive_thinking_sounds = 0
                    with self.lock:
                        self.speech_buffer = []

            else:
                # Idle: trim old buffers
                with self.lock:
                    if len(self.audio_buffer) > self.max_chunks:
                        self.audio_buffer = self.audio_buffer[-self.max_chunks:]
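

# --- Hedged usage sketch ---
# A minimal wiring example, assuming a text analyzer exposing analyze(text)
# and a 16 kHz mono float32 microphone stream via the `sounddevice` package;
# both the EchoAnalyzer class and the stream setup below are illustrative
# assumptions, not part of this module.
#
#     import sounddevice as sd
#
#     class EchoAnalyzer:
#         def analyze(self, text):
#             print(f"[Analyzer] {text}")
#
#     worker = WhisperTranscriptionWorker(EchoAnalyzer())
#     worker.set_response_callback(lambda text: print(f"[Response] {text}"))
#     worker.start()
#
#     def on_audio(indata, frames, time_info, status):
#         worker.add_audio(indata[:, 0].copy())  # mono float32 @ 16 kHz
#
#     with sd.InputStream(samplerate=ASR_SR, channels=1, dtype="float32",
#                         callback=on_audio):
#         input("Press Enter to stop...\n")
#     worker.stop()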