Merge pull request #57 from pollen-robotics/add-github-workflows
- .github/workflows/tests.yml +39 -0
- .github/workflows/typecheck.yml +29 -0
- .gitignore +1 -0
- pyproject.toml +15 -3
- src/reachy_mini_conversation_demo/audio/head_wobbler.py +7 -5
- src/reachy_mini_conversation_demo/audio/speech_tapper.py +11 -9
- src/reachy_mini_conversation_demo/camera_worker.py +26 -25
- src/reachy_mini_conversation_demo/config.py +4 -4
- src/reachy_mini_conversation_demo/console.py +7 -6
- src/reachy_mini_conversation_demo/dance_emotion_moves.py +21 -18
- src/reachy_mini_conversation_demo/main.py +5 -4
- src/reachy_mini_conversation_demo/moves.py +29 -28
- src/reachy_mini_conversation_demo/openai_realtime.py +54 -43
- src/reachy_mini_conversation_demo/tools.py +35 -33
- src/reachy_mini_conversation_demo/utils.py +8 -6
- src/reachy_mini_conversation_demo/vision/processors.py +16 -14
- src/reachy_mini_conversation_demo/vision/yolo_head_tracker.py +15 -9
- tests/audio/test_head_wobbler.py +3 -2
- uv.lock +0 -0
.github/workflows/tests.yml
ADDED
@@ -0,0 +1,39 @@
name: Tests
on:
  push:
  pull_request:

permissions:
  contents: read

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  tests:
    name: pytest (py${{ matrix.python-version }})
    runs-on: ubuntu-latest
    timeout-minutes: 15
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.12"]

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - uses: astral-sh/setup-uv@v5

      - name: Install (locked)
        env:
          GIT_LFS_SKIP_SMUDGE: "1"
        run: |
          uv sync --frozen --group dev --extra all_vision

      - name: Run tests
        run: uv run pytest -q
.github/workflows/typecheck.yml
ADDED
@@ -0,0 +1,29 @@
name: Type check

on: [push, pull_request]

permissions:
  contents: read

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  mypy:
    runs-on: ubuntu-latest
    timeout-minutes: 10
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - uses: astral-sh/setup-uv@v5

      - name: Install deps (locked) incl. vision extras
        run: uv sync --frozen --group dev --extra all_vision

      - name: Run mypy
        run: uv run mypy --pretty --show-error-codes .
.gitignore
CHANGED
@@ -29,6 +29,7 @@ coverage.xml

 # Linting and formatting
 .ruff_cache/
+.mypy_cache/

 # IDE
 .vscode/
pyproject.toml
CHANGED
@@ -12,7 +12,7 @@ requires-python = ">=3.10"
 dependencies = [
     #Media
     "aiortc>=1.13.0",
-    "fastrtc
+    "fastrtc>=0.0.33",
     "gradio>=5.49.0",
     "huggingface_hub>=0.34.4",
     "opencv-python>=4.12.0.88",
@@ -23,7 +23,7 @@ dependencies = [
     #OpenAI
     "openai>=2.1",

-    #Reachy mini
+    #Reachy mini
     "reachy_mini_dances_library",
     "reachy_mini_toolbox",
     "reachy_mini>=1.0.0.rc4",
@@ -40,7 +40,11 @@ all_vision = [
 ]

 [dependency-groups]
-dev = [
+dev = [
+    "pytest",
+    "ruff==0.12.0",
+    "mypy==1.18.2",
+]

 [project.scripts]
 reachy-mini-conversation-demo = "reachy_mini_conversation_demo.main:main"
@@ -88,3 +92,11 @@ quote-style = "double"
 indent-style = "space"
 skip-magic-trailing-comma = false
 line-ending = "auto"
+
+[tool.mypy]
+python_version = "3.12"
+files = ["src/"]
+ignore_missing_imports = true
+strict = true
+show_error_codes = true
+warn_unused_ignores = true
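The new `[tool.mypy]` table runs mypy in strict mode over `src/`, which is what drives most of the annotation changes in the files below. A minimal sketch (illustrative only, not code from this PR) of the style strict mode pushes toward:

```python
# Illustrative only (not from this PR): the annotation style strict mypy expects.
from __future__ import annotations

import threading

import numpy as np
from numpy.typing import NDArray


class Worker:
    def __init__(self) -> None:
        # Optional attributes are spelled out instead of left implicit.
        self._thread: threading.Thread | None = None
        self.latest_frame: NDArray[np.uint8] | None = None

    def set_frame(self, frame: NDArray[np.uint8]) -> None:
        self.latest_frame = frame
```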
src/reachy_mini_conversation_demo/audio/head_wobbler.py
CHANGED
@@ -5,9 +5,11 @@ import queue
 import base64
 import logging
 import threading
-from typing import Tuple
+from typing import Tuple
+from collections.abc import Callable

 import numpy as np
+from numpy.typing import NDArray

 from reachy_mini_conversation_demo.audio.speech_tapper import HOP_MS, SwayRollRT

@@ -20,13 +22,13 @@ logger = logging.getLogger(__name__)
 class HeadWobbler:
     """Converts audio deltas (base64) into head movement offsets."""

-    def __init__(self, set_speech_offsets):
+    def __init__(self, set_speech_offsets: Callable[[Tuple[float, float, float, float, float, float]], None]) -> None:
         """Initialize the head wobbler."""
         self._apply_offsets = set_speech_offsets
-        self._base_ts:
+        self._base_ts: float | None = None
         self._hops_done: int = 0

-        self.audio_queue: queue.Queue[Tuple[int, int, np.
+        self.audio_queue: "queue.Queue[Tuple[int, int, NDArray[np.int16]]]" = queue.Queue()
         self.sway = SwayRollRT()

         # Synchronization primitives
@@ -35,7 +37,7 @@
         self._generation = 0

         self._stop_event = threading.Event()
-        self._thread:
+        self._thread: threading.Thread | None = None

     def feed(self, delta_b64: str) -> None:
         """Thread-safe: push audio into the consumer queue."""
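`HeadWobbler.__init__` now types its callback as `Callable[[Tuple[float, float, float, float, float, float]], None]`. A hedged sketch of a callback satisfying that signature; `apply_offsets` is a made-up example, only the 6-float tuple shape comes from the diff:

```python
from typing import Tuple

# Hypothetical consumer of the wobbler's offsets; not part of the PR.
Offsets = Tuple[float, float, float, float, float, float]


def apply_offsets(offsets: Offsets) -> None:
    x, y, z, roll, pitch, yaw = offsets
    print(f"head offsets: x={x:.3f}, roll={roll:.3f}")


# wobbler = HeadWobbler(set_speech_offsets=apply_offsets)
```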
src/reachy_mini_conversation_demo/audio/speech_tapper.py
CHANGED
@@ -1,10 +1,11 @@
 from __future__ import annotations
 import math
-from typing import Dict, List
+from typing import Any, Dict, List
 from itertools import islice
 from collections import deque

 import numpy as np
+from numpy.typing import NDArray


 # Tunables
@@ -48,7 +49,7 @@ SWAY_ATTACK_FR = max(1, int(SWAY_ATTACK_MS / HOP_MS))
 SWAY_RELEASE_FR = max(1, int(SWAY_RELEASE_MS / HOP_MS))


-def _rms_dbfs(x: np.
+def _rms_dbfs(x: NDArray[np.float32]) -> float:
     """Root-mean-square in dBFS for float32 mono array in [-1,1]."""
     # numerically stable rms (avoid overflow)
     x = x.astype(np.float32, copy=False)
@@ -66,7 +67,7 @@ def _loudness_gain(db: float, offset: float = SENS_DB_OFFSET) -> float:
     return t**LOUDNESS_GAMMA if LOUDNESS_GAMMA != 1.0 else t


-def _to_float32_mono(x:
+def _to_float32_mono(x: NDArray[Any]) -> NDArray[np.float32]:
     """Convert arbitrary PCM array to float32 mono in [-1,1].

     Accepts shapes: (N,), (1,N), (N,1), (C,N), (N,C).
@@ -94,7 +95,7 @@ def _to_float32_mono(x: np.ndarray) -> np.ndarray:
     return a.astype(np.float32) / (scale if scale != 0.0 else 1.0)


-def _resample_linear(x: np.
+def _resample_linear(x: NDArray[np.float32], sr_in: int, sr_out: int) -> NDArray[np.float32]:
     """Lightweight linear resampler for short buffers."""
     if sr_in == sr_out or x.size == 0:
         return x
@@ -118,8 +119,8 @@ class SwayRollRT:
     def __init__(self, rng_seed: int = 7):
         """Initialize state."""
         self._seed = int(rng_seed)
-        self.samples = deque(maxlen=10 * SR)  # sliding window for VAD/env
-        self.carry = np.zeros(0, dtype=np.float32)
+        self.samples: deque[float] = deque(maxlen=10 * SR)  # sliding window for VAD/env
+        self.carry: NDArray[np.float32] = np.zeros(0, dtype=np.float32)

         self.vad_on = False
         self.vad_above = 0
@@ -150,7 +151,7 @@
         self.sway_down = 0
         self.t = 0.0

-    def feed(self, pcm:
+    def feed(self, pcm: NDArray[Any], sr: int | None) -> List[Dict[str, float]]:
         """Stream in PCM chunk. Returns a list of sway dicts, one per hop (HOP_MS).

         Args:
@@ -177,7 +178,8 @@

         while self.carry.size >= HOP:
             hop = self.carry[:HOP]
-
+            remaining: NDArray[np.float32] = self.carry[HOP:]
+            self.carry = remaining

             # keep sliding window for VAD/env computation
             # (deque accepts any iterable; list() for small HOP is fine)
@@ -260,7 +262,7 @@
                     "x_mm": x_mm,
                     "y_mm": y_mm,
                     "z_mm": z_mm,
-                }
+                },
             )

         return out
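The audio helpers now take and return `NDArray` types. A rough, self-contained sketch of an RMS-in-dBFS helper with the new `_rms_dbfs` signature (illustrative; the module's own implementation differs in detail):

```python
import numpy as np
from numpy.typing import NDArray


def rms_dbfs(x: NDArray[np.float32]) -> float:
    """RMS level in dBFS for a float32 mono buffer in [-1, 1] (sketch)."""
    x64 = x.astype(np.float64, copy=False)
    rms = float(np.sqrt(np.mean(x64 * x64)))
    return float(20.0 * np.log10(max(rms, 1e-12)))  # floor avoids log10(0)
```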
src/reachy_mini_conversation_demo/camera_worker.py
CHANGED
@@ -9,10 +9,11 @@ Ported from main_works.py camera_worker() function to provide:
 import time
 import logging
 import threading
-from typing import
+from typing import Any, List, Tuple

 import cv2
 import numpy as np
+from numpy.typing import NDArray
 from scipy.spatial.transform import Rotation as R

 from reachy_mini import ReachyMini
@@ -25,20 +26,20 @@ logger = logging.getLogger(__name__)
 class CameraWorker:
     """Thread-safe camera worker with frame buffering and face tracking."""

-    def __init__(self, reachy_mini: ReachyMini, head_tracker=None):
+    def __init__(self, reachy_mini: ReachyMini, head_tracker: Any = None) -> None:
         """Initialize."""
         self.reachy_mini = reachy_mini
         self.head_tracker = head_tracker

         # Thread-safe frame storage
-        self.latest_frame:
+        self.latest_frame: NDArray[np.uint8] | None = None
         self.frame_lock = threading.Lock()
         self._stop_event = threading.Event()
-        self._thread:
+        self._thread: threading.Thread | None = None

         # Face tracking state
         self.is_head_tracking_enabled = True
-        self.face_tracking_offsets = [
+        self.face_tracking_offsets: List[float] = [
             0.0,
             0.0,
             0.0,
@@ -49,31 +50,31 @@
         self.face_tracking_lock = threading.Lock()

         # Face tracking timing variables (same as main_works.py)
-        self.last_face_detected_time:
-        self.interpolation_start_time:
-        self.interpolation_start_pose:
+        self.last_face_detected_time: float | None = None
+        self.interpolation_start_time: float | None = None
+        self.interpolation_start_pose: NDArray[np.float32] | None = None
         self.face_lost_delay = 2.0  # seconds to wait before starting interpolation
         self.interpolation_duration = 1.0  # seconds to interpolate back to neutral

         # Track state changes
         self.previous_head_tracking_state = self.is_head_tracking_enabled

-    def get_latest_frame(self) ->
+    def get_latest_frame(self) -> NDArray[np.uint8] | None:
         """Get the latest frame (thread-safe)."""
         with self.frame_lock:
             if self.latest_frame is None:
                 return None
-
-
-
-            return frame
+            frame = self.latest_frame.copy()
+            frame_rgb: NDArray[np.uint8] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # type: ignore[assignment]
+            return frame_rgb

     def get_face_tracking_offsets(
         self,
     ) -> Tuple[float, float, float, float, float, float]:
         """Get current face tracking offsets (thread-safe)."""
         with self.face_tracking_lock:
-
+            offsets = self.face_tracking_offsets
+            return (offsets[0], offsets[1], offsets[2], offsets[3], offsets[4], offsets[5])

     def set_head_tracking_enabled(self, enabled: bool) -> None:
         """Enable/disable head tracking."""
@@ -168,12 +169,11 @@
                     rotation[2],  # roll, pitch, yaw
                 ]

-
-
-                if
-
-
-                pass
+            # No face detected while tracking enabled - set face lost timestamp
+            elif self.last_face_detected_time is None or self.last_face_detected_time == current_time:
+                # Only update if we haven't already set a face lost time
+                # (current_time check prevents overriding the disable-triggered timestamp)
+                pass

             # Handle smooth interpolation (works for both face-lost and tracking-disabled cases)
             if self.last_face_detected_time is not None:
@@ -188,11 +188,12 @@
                     current_translation = self.face_tracking_offsets[:3]
                     current_rotation_euler = self.face_tracking_offsets[3:]
                     # Convert to 4x4 pose matrix
-
-
-
-                        "xyz", current_rotation_euler
+                    pose_matrix = np.eye(4, dtype=np.float32)
+                    pose_matrix[:3, 3] = current_translation
+                    pose_matrix[:3, :3] = R.from_euler(
+                        "xyz", current_rotation_euler,
                     ).as_matrix()
+                    self.interpolation_start_pose = pose_matrix

                 # Calculate interpolation progress (t from 0 to 1)
                 elapsed_interpolation = current_time - self.interpolation_start_time
@@ -200,7 +201,7 @@

                 # Interpolate between current pose and neutral pose
                 interpolated_pose = linear_pose_interpolation(
-                    self.interpolation_start_pose, neutral_pose, t
+                    self.interpolation_start_pose, neutral_pose, t,
                 )

                 # Extract translation and rotation from interpolated pose
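The interpolation branch now builds the start pose explicitly: an identity 4x4, the translation in the last column, and the rotation from `R.from_euler("xyz", ...)`. A small self-contained sketch of that construction (the `pose_from_offsets` name is hypothetical):

```python
from typing import List

import numpy as np
from numpy.typing import NDArray
from scipy.spatial.transform import Rotation as R


def pose_from_offsets(offsets: List[float]) -> NDArray[np.float32]:
    """Build a 4x4 pose from [x, y, z, roll, pitch, yaw] offsets (sketch)."""
    pose = np.eye(4, dtype=np.float32)
    pose[:3, 3] = offsets[:3]                                    # translation
    pose[:3, :3] = R.from_euler("xyz", offsets[3:]).as_matrix()  # rotation
    return pose
```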
src/reachy_mini_conversation_demo/config.py
CHANGED
@@ -13,13 +13,13 @@ if not env_file.exists():
     raise RuntimeError(
         ".env file not found. Please create one based on .env.example:\n"
         " cp .env.example .env\n"
-        "Then add your OPENAI_API_KEY to the .env file."
+        "Then add your OPENAI_API_KEY to the .env file.",
     )

 # Load .env and verify it was loaded successfully
 if not load_dotenv():
     raise RuntimeError(
-        "Failed to load .env file. Please ensure the file is readable and properly formatted."
+        "Failed to load .env file. Please ensure the file is readable and properly formatted.",
     )

 logger.info("Configuration loaded from .env file")
@@ -33,11 +33,11 @@ class Config:
     if OPENAI_API_KEY is None:
         raise RuntimeError(
             "OPENAI_API_KEY is not set in .env file. Please add it:\n"
-            " OPENAI_API_KEY=your_api_key_here"
+            " OPENAI_API_KEY=your_api_key_here",
         )
     if not OPENAI_API_KEY.strip():
         raise RuntimeError(
-            "OPENAI_API_KEY is empty in .env file. Please provide a valid API key."
+            "OPENAI_API_KEY is empty in .env file. Please provide a valid API key.",
         )

     # Optional
src/reachy_mini_conversation_demo/console.py
CHANGED
@@ -5,6 +5,7 @@ records mic frames to the handler and plays handler audio frames to the speaker.

 import asyncio
 import logging
+from typing import List

 import librosa
 from fastrtc import AdditionalOutputs, audio_to_int16, audio_to_float32
@@ -24,9 +25,9 @@ class LocalStream:
         self.handler = handler
         self._robot = robot
         self._stop_event = asyncio.Event()
-        self._tasks = []
+        self._tasks: List[asyncio.Task[None]] = []
         # Allow the handler to flush the player queue when appropriate.
-        self.handler._clear_queue = self.clear_audio_queue
+        self.handler._clear_queue = self.clear_audio_queue

     def launch(self) -> None:
         """Start the recorder/player and run the async processing loops."""
@@ -105,12 +106,12 @@
             elif isinstance(handler_output, tuple):
                 input_sample_rate, audio_frame = handler_output
                 device_sample_rate = self._robot.media.get_audio_samplerate()
-
+                audio_frame_float = audio_to_float32(audio_frame.squeeze())
                 if input_sample_rate != device_sample_rate:
-
-
+                    audio_frame_float = librosa.resample(
+                        audio_frame_float, orig_sr=input_sample_rate, target_sr=device_sample_rate,
                     )
-                self._robot.media.push_audio_sample(
+                self._robot.media.push_audio_sample(audio_frame_float)

             else:
                 logger.debug("Ignoring output type=%s", type(handler_output).__name__)
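The playback path converts handler audio to float32 and resamples with librosa only when the handler and device rates differ. A standalone sketch of that conversion (the function name is made up; the PR itself uses fastrtc's `audio_to_float32` rather than a manual scale):

```python
import librosa
import numpy as np


def match_device_rate(frame: np.ndarray, in_sr: int, out_sr: int) -> np.ndarray:
    """Int16 PCM -> float32 in [-1, 1], resampled to the device rate (sketch)."""
    audio = frame.astype(np.float32) / 32768.0
    if in_sr != out_sr:
        audio = librosa.resample(audio, orig_sr=in_sr, target_sr=out_sr)
    return audio
```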
src/reachy_mini_conversation_demo/dance_emotion_moves.py
CHANGED
@@ -9,6 +9,7 @@ import logging
 from typing import Tuple

 import numpy as np
+from numpy.typing import NDArray

 from reachy_mini.motion.move import Move
 from reachy_mini.motion.recorded_move import RecordedMoves
@@ -18,7 +19,7 @@ from reachy_mini_dances_library.dance_move import DanceMove
 logger = logging.getLogger(__name__)


-class DanceQueueMove(Move):
+class DanceQueueMove(Move):  # type: ignore
     """Wrapper for dance moves to work with the movement queue system."""

     def __init__(self, move_name: str):
@@ -29,9 +30,9 @@
     @property
     def duration(self) -> float:
         """Duration property required by official Move interface."""
-        return self.dance_move.duration
+        return float(self.dance_move.duration)

-    def evaluate(self, t: float) -> tuple[np.
+    def evaluate(self, t: float) -> tuple[NDArray[np.float64] | None, NDArray[np.float64] | None, float | None]:
         """Evaluate dance move at time t."""
         try:
             # Get the pose from the dance move
@@ -49,10 +50,10 @@
             from reachy_mini.utils import create_head_pose

             neutral_head_pose = create_head_pose(0, 0, 0, 0, 0, 0, degrees=True)
-            return (neutral_head_pose, np.array([0.0, 0.0]), 0.0)
+            return (neutral_head_pose, np.array([0.0, 0.0], dtype=np.float64), 0.0)


-class EmotionQueueMove(Move):
+class EmotionQueueMove(Move):  # type: ignore
     """Wrapper for emotion moves to work with the movement queue system."""

     def __init__(self, emotion_name: str, recorded_moves: RecordedMoves):
@@ -63,9 +64,9 @@
     @property
     def duration(self) -> float:
         """Duration property required by official Move interface."""
-        return self.emotion_move.duration
+        return float(self.emotion_move.duration)

-    def evaluate(self, t: float) -> tuple[np.
+    def evaluate(self, t: float) -> tuple[NDArray[np.float64] | None, NDArray[np.float64] | None, float | None]:
         """Evaluate emotion move at time t."""
         try:
             # Get the pose from the emotion move
@@ -83,20 +84,20 @@
             from reachy_mini.utils import create_head_pose

             neutral_head_pose = create_head_pose(0, 0, 0, 0, 0, 0, degrees=True)
-            return (neutral_head_pose, np.array([0.0, 0.0]), 0.0)
+            return (neutral_head_pose, np.array([0.0, 0.0], dtype=np.float64), 0.0)


-class GotoQueueMove(Move):
+class GotoQueueMove(Move):  # type: ignore
     """Wrapper for goto moves to work with the movement queue system."""

     def __init__(
         self,
-        target_head_pose: np.
-        start_head_pose: np.
+        target_head_pose: NDArray[np.float32],
+        start_head_pose: NDArray[np.float32] | None = None,
         target_antennas: Tuple[float, float] = (0, 0),
-        start_antennas: Tuple[float, float] = None,
+        start_antennas: Tuple[float, float] | None = None,
         target_body_yaw: float = 0,
-        start_body_yaw: float = None,
+        start_body_yaw: float | None = None,
         duration: float = 1.0,
     ):
         """Initialize a GotoQueueMove."""
@@ -113,7 +114,7 @@
         """Duration property required by official Move interface."""
         return self._duration

-    def evaluate(self, t: float) -> tuple[np.
+    def evaluate(self, t: float) -> tuple[NDArray[np.float64] | None, NDArray[np.float64] | None, float | None]:
         """Evaluate goto move at time t using linear interpolation."""
         try:
             from reachy_mini.utils import create_head_pose
@@ -136,7 +137,8 @@
                 [
                     self.start_antennas[0] + (self.target_antennas[0] - self.start_antennas[0]) * t_clamped,
                     self.start_antennas[1] + (self.target_antennas[1] - self.start_antennas[1]) * t_clamped,
-                ]
+                ],
+                dtype=np.float64,
             )

             # Interpolate body yaw
@@ -146,6 +148,7 @@

         except Exception as e:
             logger.error(f"Error evaluating goto move at t={t}: {e}")
-            # Return target pose on error - convert
-
-
+            # Return target pose on error - convert to float64
+            target_head_pose_f64 = self.target_head_pose.astype(np.float64)
+            target_antennas_array = np.array([self.target_antennas[0], self.target_antennas[1]], dtype=np.float64)
+            return (target_head_pose_f64, target_antennas_array, self.target_body_yaw)
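`GotoQueueMove.evaluate` interpolates the antennas linearly and now returns `float64` arrays. A distilled sketch of the same interpolation, assuming `t` is clamped to [0, 1]:

```python
import numpy as np
from numpy.typing import NDArray


def lerp_antennas(
    start: tuple[float, float], target: tuple[float, float], t: float
) -> NDArray[np.float64]:
    """Linearly interpolate the two antenna angles, clamping t to [0, 1] (sketch)."""
    t = min(max(t, 0.0), 1.0)
    return np.array(
        [
            start[0] + (target[0] - start[0]) * t,
            start[1] + (target[1] - start[1]) * t,
        ],
        dtype=np.float64,
    )
```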
src/reachy_mini_conversation_demo/main.py
CHANGED
@@ -2,6 +2,7 @@

 import os
 import sys
+from typing import Any, Dict, List

 import gradio as gr
 from fastapi import FastAPI
@@ -20,13 +21,13 @@ from reachy_mini_conversation_demo.openai_realtime import OpenaiRealtimeHandler
 from reachy_mini_conversation_demo.audio.head_wobbler import HeadWobbler


-def update_chatbot(chatbot:
+def update_chatbot(chatbot: List[Dict[str, Any]], response: Dict[str, Any]) -> List[Dict[str, Any]]:
     """Update the chatbot with AdditionalOutputs."""
     chatbot.append(response)
     return chatbot


-def main():
+def main() -> None:
     """Entrypoint for the Reachy Mini conversation demo."""
     args = parse_args()
@@ -41,7 +42,7 @@
     # Check if running in simulation mode without --gradio
     if robot.client.get_status()["simulation_enabled"] and not args.gradio:
         logger.error(
-            "Simulation mode requires Gradio interface. Please use --gradio flag when running in simulation mode."
+            "Simulation mode requires Gradio interface. Please use --gradio flag when running in simulation mode.",
         )
         robot.client.disconnect()
         sys.exit(1)
@@ -76,7 +77,7 @@

     handler = OpenaiRealtimeHandler(deps)

-    stream_manager = None
+    stream_manager: gr.Blocks | LocalStream | None = None

     if args.gradio:
         stream = Stream(
src/reachy_mini_conversation_demo/moves.py
CHANGED
@@ -36,11 +36,12 @@ import time
 import logging
 import threading
 from queue import Empty, Queue
-from typing import Any,
+from typing import Any, Dict, Tuple
 from collections import deque
 from dataclasses import dataclass

 import numpy as np
+from numpy.typing import NDArray

 from reachy_mini import ReachyMini
 from reachy_mini.utils import create_head_pose
@@ -57,15 +58,15 @@
 CONTROL_LOOP_FREQUENCY_HZ = 100.0  # Hz - Target frequency for the movement control loop

 # Type definitions
-FullBodyPose = Tuple[np.
+FullBodyPose = Tuple[NDArray[np.float32], Tuple[float, float], float]  # (head_pose_4x4, antennas, body_yaw)


-class BreathingMove(Move):
+class BreathingMove(Move):  # type: ignore
     """Breathing move with interpolation to neutral and then continuous breathing patterns."""

     def __init__(
         self,
-        interpolation_start_pose: np.
+        interpolation_start_pose: NDArray[np.float32],
         interpolation_start_antennas: Tuple[float, float],
         interpolation_duration: float = 1.0,
     ):
@@ -96,7 +97,7 @@
         """Duration property required by official Move interface."""
         return float("inf")  # Continuous breathing (never ends naturally)

-    def evaluate(self, t: float) -> tuple[np.
+    def evaluate(self, t: float) -> tuple[NDArray[np.float64] | None, NDArray[np.float64] | None, float | None]:
         """Evaluate breathing move at time t."""
         if t < self.interpolation_duration:
             # Phase 1: Interpolate to neutral base position
@@ -104,13 +105,14 @@

             # Interpolate head pose
             head_pose = linear_pose_interpolation(
-                self.interpolation_start_pose, self.neutral_head_pose, interpolation_t
+                self.interpolation_start_pose, self.neutral_head_pose, interpolation_t,
             )

             # Interpolate antennas
-
+            antennas_interp = (
                 1 - interpolation_t
             ) * self.interpolation_start_antennas + interpolation_t * self.neutral_antennas
+            antennas = antennas_interp.astype(np.float64)

         else:
             # Phase 2: Breathing patterns from neutral base
@@ -122,7 +124,7 @@

             # Antenna sway (opposite directions)
             antenna_sway = self.antenna_sway_amplitude * np.sin(2 * np.pi * self.antenna_frequency * breathing_time)
-            antennas = np.array([antenna_sway, -antenna_sway])
+            antennas = np.array([antenna_sway, -antenna_sway], dtype=np.float64)

         # Return in official Move interface format: (head_pose, antennas_array, body_yaw)
         return (head_pose, antennas, 0.0)
@@ -168,8 +170,8 @@
     """State tracking for the movement system."""

     # Primary move state
-    current_move:
-    move_start_time:
+    current_move: Move | None = None
+    move_start_time: float | None = None
     last_activity_time: float = 0.0

     # Secondary move state (offsets)
@@ -191,7 +193,7 @@
     )

     # Status flags
-    last_primary_pose:
+    last_primary_pose: FullBodyPose | None = None

     def update_activity(self) -> None:
         """Update the last activity time."""
@@ -242,7 +244,7 @@
     def __init__(
         self,
         current_robot: ReachyMini,
-        camera_worker=None,
+        camera_worker: "Any" = None,
     ):
         """Initialize movement manager."""
         self.current_robot = current_robot
@@ -258,7 +260,7 @@
         self.state.last_primary_pose = (neutral_pose, (0.0, 0.0), 0.0)

         # Move queue (primary moves)
-        self.move_queue = deque()
+        self.move_queue: deque[Move] = deque()

         # Configuration
         self.idle_inactivity_delay = 0.3  # seconds
@@ -266,7 +268,7 @@
         self.target_period = 1.0 / self.target_frequency

         self._stop_event = threading.Event()
-        self._thread:
+        self._thread: threading.Thread | None = None
         self._is_listening = False
         self._last_commanded_pose: FullBodyPose = clone_full_body_pose(self.state.last_primary_pose)
         self._listening_antennas: Tuple[float, float] = self._last_commanded_pose[1]
@@ -281,7 +283,7 @@
         self._set_target_err_suppressed = 0

         # Cross-thread signalling
-        self._command_queue: Queue[
+        self._command_queue: "Queue[Tuple[str, Any]]" = Queue()
         self._speech_offsets_lock = threading.Lock()
         self._pending_speech_offsets: Tuple[float, float, float, float, float, float] = (
             0.0,
@@ -383,7 +385,7 @@

     def _apply_pending_offsets(self) -> None:
         """Apply the most recent speech/face offset updates."""
-        speech_offsets:
+        speech_offsets: Tuple[float, float, float, float, float, float] | None = None
         with self._speech_offsets_lock:
             if self._speech_offsets_dirty:
                 speech_offsets = self._pending_speech_offsets
@@ -393,7 +395,7 @@
                 self.state.speech_offsets = speech_offsets
                 self.state.update_activity()

-        face_offsets:
+        face_offsets: Tuple[float, float, float, float, float, float] | None = None
         with self._face_offsets_lock:
             if self._face_offsets_dirty:
                 face_offsets = self._pending_face_offsets
@@ -549,14 +551,13 @@
             )

             self.state.last_primary_pose = clone_full_body_pose(primary_full_body_pose)
+        # Otherwise reuse the last primary pose so we avoid jumps between moves
+        elif self.state.last_primary_pose is not None:
+            primary_full_body_pose = clone_full_body_pose(self.state.last_primary_pose)
         else:
-
-
-
-            else:
-                neutral_head_pose = create_head_pose(0, 0, 0, 0, 0, 0, degrees=True)
-                primary_full_body_pose = (neutral_head_pose, (0.0, 0.0), 0.0)
-                self.state.last_primary_pose = clone_full_body_pose(primary_full_body_pose)
+            neutral_head_pose = create_head_pose(0, 0, 0, 0, 0, 0, degrees=True)
+            primary_full_body_pose = (neutral_head_pose, (0.0, 0.0), 0.0)
+            self.state.last_primary_pose = clone_full_body_pose(primary_full_body_pose)

         return primary_full_body_pose

@@ -631,7 +632,7 @@

         return antennas_cmd

-    def _issue_control_command(self, head: np.
+    def _issue_control_command(self, head: NDArray[np.float32], antennas: Tuple[float, float], body_yaw: float) -> None:
         """Send the fused pose to the robot with throttled error logging."""
         try:
             self.current_robot.set_target(head=head, antennas=antennas, body_yaw=body_yaw)
@@ -651,7 +652,7 @@
         self._last_commanded_pose = clone_full_body_pose((head, antennas, body_yaw))

     def _update_frequency_stats(
-        self, loop_start: float, prev_loop_start: float, stats: LoopFrequencyStats
+        self, loop_start: float, prev_loop_start: float, stats: LoopFrequencyStats,
     ) -> LoopFrequencyStats:
         """Update frequency statistics based on the current loop start time."""
         period = loop_start - prev_loop_start
@@ -664,7 +665,7 @@
         stats.min_freq = min(stats.min_freq, stats.last_freq)
         return stats

-    def _schedule_next_tick(self, loop_start: float, stats: LoopFrequencyStats) ->
+    def _schedule_next_tick(self, loop_start: float, stats: LoopFrequencyStats) -> Tuple[float, LoopFrequencyStats]:
         """Compute sleep time to maintain target frequency and update potential freq."""
         computation_time = self._now() - loop_start
         stats.potential_freq = 1.0 / computation_time if computation_time > 0 else float("inf")
@@ -729,7 +730,7 @@
             self._thread = None
         logger.debug("Move worker stopped")

-    def get_status(self) ->
+    def get_status(self) -> Dict[str, Any]:
         """Return a lightweight status snapshot for observability."""
         with self._status_lock:
             pose_snapshot = clone_full_body_pose(self._last_commanded_pose)
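`FullBodyPose` is now an explicit alias: a 4x4 head pose, an antenna pair, and a body yaw. A sketch of constructing and cloning such a tuple (the `clone` helper here is illustrative, not the repo's `clone_full_body_pose`):

```python
import numpy as np
from numpy.typing import NDArray

# Same shape as the alias in the diff: (4x4 head pose, (left, right) antennas, body yaw).
FullBodyPose = tuple[NDArray[np.float32], tuple[float, float], float]


def clone(pose: FullBodyPose) -> FullBodyPose:
    """Copy the mutable head matrix so later edits don't alias the original (sketch)."""
    head, antennas, body_yaw = pose
    return (head.copy(), (antennas[0], antennas[1]), body_yaw)


neutral: FullBodyPose = (np.eye(4, dtype=np.float32), (0.0, 0.0), 0.0)
```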
src/reachy_mini_conversation_demo/openai_realtime.py
CHANGED
|
@@ -2,12 +2,14 @@ import json
|
|
| 2 |
import base64
|
| 3 |
import asyncio
|
| 4 |
import logging
|
|
|
|
| 5 |
from datetime import datetime
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
import gradio as gr
|
| 9 |
from openai import AsyncOpenAI
|
| 10 |
from fastrtc import AdditionalOutputs, AsyncStreamHandler, wait_for_item
|
|
|
|
| 11 |
|
| 12 |
from reachy_mini_conversation_demo.tools import (
|
| 13 |
ALL_TOOL_SPECS,
|
|
@@ -33,18 +35,18 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 33 |
)
|
| 34 |
self.deps = deps
|
| 35 |
|
| 36 |
-
self.connection = None
|
| 37 |
-
self.output_queue = asyncio.Queue()
|
| 38 |
|
| 39 |
self.last_activity_time = asyncio.get_event_loop().time()
|
| 40 |
self.start_time = asyncio.get_event_loop().time()
|
| 41 |
self.is_idle_tool_call = False
|
| 42 |
|
| 43 |
-
def copy(self):
|
| 44 |
"""Create a copy of the handler."""
|
| 45 |
return OpenaiRealtimeHandler(self.deps)
|
| 46 |
|
| 47 |
-
async def start_up(self):
|
| 48 |
"""Start the handler."""
|
| 49 |
self.client = AsyncOpenAI(api_key=config.OPENAI_API_KEY)
|
| 50 |
async with self.client.beta.realtime.connect(model=config.MODEL_NAME) as conn:
|
|
@@ -59,10 +61,10 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 59 |
},
|
| 60 |
"voice": "ballad",
|
| 61 |
"instructions": SESSION_INSTRUCTIONS,
|
| 62 |
-
"tools": ALL_TOOL_SPECS,
|
| 63 |
"tool_choice": "auto",
|
| 64 |
"temperature": 0.7,
|
| 65 |
-
}
|
| 66 |
)
|
| 67 |
|
| 68 |
# Manage event received from the openai server
|
|
@@ -70,9 +72,10 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 70 |
async for event in self.connection:
|
| 71 |
logger.debug(f"OpenAI event: {event.type}")
|
| 72 |
if event.type == "input_audio_buffer.speech_started":
|
| 73 |
-
if hasattr(self,
|
| 74 |
self._clear_queue()
|
| 75 |
-
self.deps.head_wobbler
|
|
|
|
| 76 |
self.deps.movement_manager.set_listening(True)
|
| 77 |
logger.debug("User speech started")
|
| 78 |
|
|
@@ -83,7 +86,8 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 83 |
if event.type in ("response.audio.completed", "response.completed"):
|
| 84 |
# Doesn't seem to be called
|
| 85 |
logger.debug("response completed")
|
| 86 |
-
self.deps.head_wobbler
|
|
|
|
| 87 |
|
| 88 |
if event.type == "response.created":
|
| 89 |
logger.debug("Response created")
|
|
@@ -91,7 +95,6 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 91 |
if event.type == "response.done":
|
| 92 |
# Doesn't mean the audio is done playing
|
| 93 |
logger.debug("Response done")
|
| 94 |
-
pass
|
| 95 |
|
| 96 |
if event.type == "conversation.item.input_audio_transcription.completed":
|
| 97 |
logger.debug(f"User transcript: {event.transcript}")
|
|
@@ -102,7 +105,8 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 102 |
await self.output_queue.put(AdditionalOutputs({"role": "assistant", "content": event.transcript}))
|
| 103 |
|
| 104 |
if event.type == "response.audio.delta":
|
| 105 |
-
self.deps.head_wobbler
|
|
|
|
| 106 |
self.last_activity_time = asyncio.get_event_loop().time()
|
| 107 |
logger.debug("last activity time updated to %s", self.last_activity_time)
|
| 108 |
await self.output_queue.put(
|
|
@@ -118,6 +122,10 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 118 |
args_json_str = getattr(event, "arguments", None)
|
| 119 |
call_id = getattr(event, "call_id", None)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
try:
|
| 122 |
tool_result = await dispatch_tool_call(tool_name, args_json_str, self.deps)
|
| 123 |
logger.debug("Tool '%s' executed successfully", tool_name)
|
|
@@ -127,22 +135,23 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 127 |
tool_result = {"error": str(e)}
|
| 128 |
|
| 129 |
# send the tool result back
|
| 130 |
-
|
| 131 |
-
item
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
| 137 |
|
| 138 |
await self.output_queue.put(
|
| 139 |
AdditionalOutputs(
|
| 140 |
{
|
| 141 |
"role": "assistant",
|
| 142 |
"content": json.dumps(tool_result),
|
| 143 |
-
"metadata": {"title": "🛠️ Used tool "
|
| 144 |
},
|
| 145 |
-
)
|
| 146 |
)
|
| 147 |
|
| 148 |
if tool_name == "camera" and "b64_im" in tool_result:
|
|
@@ -157,37 +166,39 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 157 |
"role": "user",
|
| 158 |
"content": [
|
| 159 |
{
|
| 160 |
-
"type": "input_image",
|
| 161 |
"image_url": f"data:image/jpeg;base64,{b64_im}",
|
| 162 |
-
}
|
| 163 |
],
|
| 164 |
-
}
|
| 165 |
)
|
| 166 |
logger.info("Added camera image to conversation")
|
| 167 |
|
| 168 |
-
|
| 169 |
-
|
|
|
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
| 177 |
)
|
| 178 |
-
)
|
| 179 |
|
| 180 |
if not self.is_idle_tool_call:
|
| 181 |
await self.connection.response.create(
|
| 182 |
response={
|
| 183 |
-
"instructions": "Use the tool result just returned and answer concisely in speech."
|
| 184 |
-
}
|
| 185 |
)
|
| 186 |
else:
|
| 187 |
self.is_idle_tool_call = False
|
| 188 |
|
| 189 |
# re synchronize the head wobble after a tool call that may have taken some time
|
| 190 |
-
self.deps.head_wobbler
|
|
|
|
| 191 |
|
| 192 |
# server error
|
| 193 |
if event.type == "error":
|
|
@@ -197,7 +208,7 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 197 |
await self.output_queue.put(AdditionalOutputs({"role": "assistant", "content": f"[error] {msg}"}))
|
| 198 |
|
| 199 |
# Microphone receive
|
| 200 |
-
async def receive(self, frame:
|
| 201 |
"""Receive audio frame from the microphone and send it to the openai server."""
|
| 202 |
if not self.connection:
|
| 203 |
return
|
|
@@ -205,9 +216,9 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 205 |
array = array.squeeze()
|
| 206 |
audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
|
| 207 |
# Fills the input audio buffer to be sent to the server
|
| 208 |
-
await self.connection.input_audio_buffer.append(audio=audio_message)
|
| 209 |
|
| 210 |
-
async def emit(self) ->
|
| 211 |
"""Emit audio frame to be played by the speaker."""
|
| 212 |
# sends to the stream the stuff put in the output queue by the openai event handler
|
| 213 |
# This is called periodically by the fastrtc Stream
|
|
@@ -219,7 +230,7 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 219 |
|
| 220 |
self.last_activity_time = asyncio.get_event_loop().time() # avoid repeated resets
|
| 221 |
|
| 222 |
-
return await wait_for_item(self.output_queue)
|
| 223 |
|
| 224 |
async def shutdown(self) -> None:
|
| 225 |
"""Shutdown the handler."""
|
|
@@ -227,7 +238,7 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 227 |
await self.connection.close()
|
| 228 |
self.connection = None
|
| 229 |
|
| 230 |
-
def format_timestamp(self):
|
| 231 |
"""Format current timestamp with date, time and elapsed seconds."""
|
| 232 |
current_time = asyncio.get_event_loop().time()
|
| 233 |
elapsed_seconds = current_time - self.start_time
|
|
@@ -236,7 +247,7 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 236 |
|
| 237 |
|
| 238 |
|
| 239 |
-
async def send_idle_signal(self, idle_duration) -> None:
|
| 240 |
"""Send an idle signal to the openai server."""
|
| 241 |
logger.debug("Sending idle signal")
|
| 242 |
self.is_idle_tool_call = True
|
|
@@ -249,12 +260,12 @@ class OpenaiRealtimeHandler(AsyncStreamHandler):
|
|
| 249 |
"type": "message",
|
| 250 |
"role": "user",
|
| 251 |
"content": [{"type": "input_text", "text": timestamp_msg}],
|
| 252 |
-
}
|
| 253 |
)
|
| 254 |
await self.connection.response.create(
|
| 255 |
response={
|
| 256 |
"modalities": ["text"],
|
| 257 |
"instructions": "You MUST respond with function calls only - no speech or text. Choose appropriate actions for idle behavior.",
|
| 258 |
"tool_choice": "required",
|
| 259 |
-
}
|
| 260 |
)
|
|
|
|
import base64
import asyncio
import logging
+from typing import Any, Tuple
from datetime import datetime

import numpy as np
import gradio as gr
from openai import AsyncOpenAI
from fastrtc import AdditionalOutputs, AsyncStreamHandler, wait_for_item
+from numpy.typing import NDArray

from reachy_mini_conversation_demo.tools import (
    ALL_TOOL_SPECS,
...
        )
        self.deps = deps

+       self.connection: Any | None = None
+       self.output_queue: "asyncio.Queue[Tuple[int, NDArray[np.int16]] | AdditionalOutputs]" = asyncio.Queue()

        self.last_activity_time = asyncio.get_event_loop().time()
        self.start_time = asyncio.get_event_loop().time()
        self.is_idle_tool_call = False

+   def copy(self) -> "OpenaiRealtimeHandler":
        """Create a copy of the handler."""
        return OpenaiRealtimeHandler(self.deps)

+   async def start_up(self) -> None:
        """Start the handler."""
        self.client = AsyncOpenAI(api_key=config.OPENAI_API_KEY)
        async with self.client.beta.realtime.connect(model=config.MODEL_NAME) as conn:
...
                    },
                    "voice": "ballad",
                    "instructions": SESSION_INSTRUCTIONS,
+                   "tools": ALL_TOOL_SPECS,  # type: ignore[typeddict-item]
                    "tool_choice": "auto",
                    "temperature": 0.7,
+               },
            )

            # Manage event received from the openai server
...
            async for event in self.connection:
                logger.debug(f"OpenAI event: {event.type}")
                if event.type == "input_audio_buffer.speech_started":
+                   if hasattr(self, "_clear_queue") and callable(self._clear_queue):
                        self._clear_queue()
+                   if self.deps.head_wobbler is not None:
+                       self.deps.head_wobbler.reset()
                    self.deps.movement_manager.set_listening(True)
                    logger.debug("User speech started")

...
                if event.type in ("response.audio.completed", "response.completed"):
                    # Doesn't seem to be called
                    logger.debug("response completed")
+                   if self.deps.head_wobbler is not None:
+                       self.deps.head_wobbler.reset()

                if event.type == "response.created":
                    logger.debug("Response created")
...
                if event.type == "response.done":
                    # Doesn't mean the audio is done playing
                    logger.debug("Response done")

                if event.type == "conversation.item.input_audio_transcription.completed":
                    logger.debug(f"User transcript: {event.transcript}")
...
                    await self.output_queue.put(AdditionalOutputs({"role": "assistant", "content": event.transcript}))

                if event.type == "response.audio.delta":
+                   if self.deps.head_wobbler is not None:
+                       self.deps.head_wobbler.feed(event.delta)
                    self.last_activity_time = asyncio.get_event_loop().time()
                    logger.debug("last activity time updated to %s", self.last_activity_time)
                    await self.output_queue.put(
...
                    args_json_str = getattr(event, "arguments", None)
                    call_id = getattr(event, "call_id", None)

+                   if not isinstance(tool_name, str) or not isinstance(args_json_str, str):
+                       logger.error("Invalid tool call: tool_name=%s, args=%s", tool_name, args_json_str)
+                       continue
+
                    try:
                        tool_result = await dispatch_tool_call(tool_name, args_json_str, self.deps)
                        logger.debug("Tool '%s' executed successfully", tool_name)
...
                        tool_result = {"error": str(e)}

                    # send the tool result back
+                   if isinstance(call_id, str):
+                       await self.connection.conversation.item.create(
+                           item={
+                               "type": "function_call_output",
+                               "call_id": call_id,
+                               "output": json.dumps(tool_result),
+                           },
+                       )

                    await self.output_queue.put(
                        AdditionalOutputs(
                            {
                                "role": "assistant",
                                "content": json.dumps(tool_result),
+                               "metadata": {"title": f"🛠️ Used tool {tool_name}", "status": "done"},
                            },
+                       ),
                    )

                    if tool_name == "camera" and "b64_im" in tool_result:
...
                                "role": "user",
                                "content": [
                                    {
+                                       "type": "input_image",  # type: ignore[typeddict-item]
                                        "image_url": f"data:image/jpeg;base64,{b64_im}",
+                                   },
                                ],
+                           },
                        )
                        logger.info("Added camera image to conversation")

+                       if self.deps.camera_worker is not None:
+                           np_img = self.deps.camera_worker.get_latest_frame()
+                           img = gr.Image(value=np_img)

+                           await self.output_queue.put(
+                               AdditionalOutputs(
+                                   {
+                                       "role": "assistant",
+                                       "content": img,
+                                   },
+                               ),
                            )

                    if not self.is_idle_tool_call:
                        await self.connection.response.create(
                            response={
+                               "instructions": "Use the tool result just returned and answer concisely in speech.",
+                           },
                        )
                    else:
                        self.is_idle_tool_call = False

                    # re synchronize the head wobble after a tool call that may have taken some time
+                   if self.deps.head_wobbler is not None:
+                       self.deps.head_wobbler.reset()

                # server error
                if event.type == "error":
...
                    await self.output_queue.put(AdditionalOutputs({"role": "assistant", "content": f"[error] {msg}"}))

    # Microphone receive
+   async def receive(self, frame: Tuple[int, NDArray[np.int16]]) -> None:
        """Receive audio frame from the microphone and send it to the openai server."""
        if not self.connection:
            return
...
        array = array.squeeze()
        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
        # Fills the input audio buffer to be sent to the server
+       await self.connection.input_audio_buffer.append(audio=audio_message)

+   async def emit(self) -> Tuple[int, NDArray[np.int16]] | AdditionalOutputs | None:
        """Emit audio frame to be played by the speaker."""
        # sends to the stream the stuff put in the output queue by the openai event handler
        # This is called periodically by the fastrtc Stream
...

        self.last_activity_time = asyncio.get_event_loop().time()  # avoid repeated resets

+       return await wait_for_item(self.output_queue)  # type: ignore[no-any-return]

    async def shutdown(self) -> None:
        """Shutdown the handler."""
...
            await self.connection.close()
        self.connection = None

+   def format_timestamp(self) -> str:
        """Format current timestamp with date, time and elapsed seconds."""
        current_time = asyncio.get_event_loop().time()
        elapsed_seconds = current_time - self.start_time
...



+   async def send_idle_signal(self, idle_duration: float) -> None:
        """Send an idle signal to the openai server."""
        logger.debug("Sending idle signal")
        self.is_idle_tool_call = True
...
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": timestamp_msg}],
+           },
        )
        await self.connection.response.create(
            response={
                "modalities": ["text"],
                "instructions": "You MUST respond with function calls only - no speech or text. Choose appropriate actions for idle behavior.",
                "tool_choice": "required",
+           },
        )
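The handler changes above repeat one pattern: optional dependencies are typed as "X | None" and every call site is guarded with an explicit None check so mypy can narrow the type, and the output queue is annotated with its exact payload union. Below is a minimal, self-contained sketch of that pattern; the Wobbler class and the queue payload are illustrative stand-ins, not the project's real API.

    import asyncio
    from typing import Tuple

    import numpy as np
    from numpy.typing import NDArray


    class Wobbler:
        """Stand-in for an optional audio-reactive dependency."""

        def reset(self) -> None:
            print("wobble state cleared")


    class Handler:
        def __init__(self, wobbler: Wobbler | None = None) -> None:
            # Optional dep typed as "X | None": mypy forces a guard at each call site.
            self.wobbler: Wobbler | None = wobbler
            # Queue annotated with its payload union, so consumers keep a precise type.
            self.queue: "asyncio.Queue[Tuple[int, NDArray[np.int16]] | None]" = asyncio.Queue()

        def on_speech_started(self) -> None:
            if self.wobbler is not None:  # narrowing: inside this branch the type is Wobbler
                self.wobbler.reset()


    async def main() -> None:
        handler = Handler(Wobbler())
        handler.on_speech_started()
        await handler.queue.put((16000, np.zeros(160, dtype=np.int16)))
        print(await handler.queue.get())


    asyncio.run(main())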
src/reachy_mini_conversation_demo/tools.py
CHANGED
@@ -4,7 +4,7 @@ import json
import asyncio
import inspect
import logging
-from typing import Any, Dict,
from dataclasses import dataclass

from reachy_mini import ReachyMini
@@ -36,9 +36,9 @@ except ImportError as e:
EMOTION_AVAILABLE = False


-def get_concrete_subclasses(base):
    """Recursively find all concrete (non-abstract) subclasses of a base class."""
-   result = []
    for cls in base.__subclasses__():
        if not inspect.isabstract(cls):
            result.append(cls)
@@ -58,9 +58,9 @@ class ToolDependencies:
    reachy_mini: ReachyMini
    movement_manager: Any  # MovementManager from moves.py
    # Optional deps
-   camera_worker:
-   vision_manager:
-   head_wobbler:
    motion_duration_s: float = 1.0


@@ -88,7 +88,7 @@ class Tool(abc.ABC):
        }

    @abc.abstractmethod
-   async def __call__(self, deps: ToolDependencies, **kwargs) -> Dict[str, Any]:
        """Async tool execution entrypoint."""
        raise NotImplementedError

@@ -113,7 +113,7 @@ class MoveHead(Tool):
    }

    # mapping: direction -> args for create_head_pose
-   DELTAS:
        "left": (0, 0, 0, 0, 0, 40),
        "right": (0, 0, 0, 0, 0, -40),
        "up": (0, 0, 0, 0, -30, 0),
@@ -121,9 +121,12 @@ class MoveHead(Tool):
        "front": (0, 0, 0, 0, 0, 0),
    }

-   async def __call__(self, deps: ToolDependencies, **kwargs) -> Dict[str, Any]:
        """Move head in a given direction."""
-
        logger.info("Tool call: move_head direction=%s", direction)

        deltas = self.DELTAS.get(direction, self.DELTAS["front"])
@@ -177,7 +180,7 @@ class Camera(Tool):
        "required": ["question"],
    }

-   async def __call__(self, deps: ToolDependencies, **kwargs) -> Dict[str, Any]:
        """Take a picture with the camera and ask a question about it."""
        image_query = (kwargs.get("question") or "").strip()
        if not image_query:
@@ -199,7 +202,7 @@ class Camera(Tool):
        # Use vision manager for processing if available
        if deps.vision_manager is not None:
            vision_result = await asyncio.to_thread(
-               deps.vision_manager.processor.process_image, frame, image_query
            )
            if isinstance(vision_result, dict) and "error" in vision_result:
                return vision_result
@@ -208,17 +211,16 @@ class Camera(Tool):
            if isinstance(vision_result, str)
            else {"error": "vision returned non-string"}
        )
-
-
-       import base64

-

-
-
-
-
-


class HeadTracking(Tool):
@@ -232,7 +234,7 @@ class HeadTracking(Tool):
        "required": ["start"],
    }

-   async def __call__(self, deps: ToolDependencies, **kwargs) -> Dict[str, Any]:
        """Enable or disable head tracking."""
        enable = bool(kwargs.get("start"))

@@ -288,12 +290,12 @@ class Dance(Tool):
        "required": [],
    }

-   async def __call__(self, deps: ToolDependencies, **kwargs) -> Dict[str, Any]:
        """Play a named or random dance move once (or repeat). Non-blocking."""
        if not DANCE_AVAILABLE:
            return {"error": "Dance system not available"}

-       move_name = kwargs.get("move"
        repeat = int(kwargs.get("repeat", 1))

        logger.info("Tool call: dance move=%s repeat=%d", move_name, repeat)
@@ -326,12 +328,12 @@ class StopDance(Tool):
            "dummy": {
                "type": "boolean",
                "description": "dummy boolean, set it to true",
-           }
        },
        "required": ["dummy"],
    }

-   async def __call__(self, deps: ToolDependencies, **kwargs) -> Dict[str, Any]:
        """Stop the current dance move."""
        logger.info("Tool call: stop_dance")
        movement_manager = deps.movement_manager
@@ -373,7 +375,7 @@ class PlayEmotion(Tool):
        "required": ["emotion"],
    }

-   async def __call__(self, deps: ToolDependencies, **kwargs) -> Dict[str, Any]:
        """Play a pre-recorded emotion."""
        if not EMOTION_AVAILABLE:
            return {"error": "Emotion system not available"}
@@ -399,7 +401,7 @@ class PlayEmotion(Tool):

        except Exception as e:
            logger.exception("Failed to play emotion")
-           return {"error": f"Failed to play emotion: {


class StopEmotion(Tool):
@@ -413,12 +415,12 @@ class StopEmotion(Tool):
            "dummy": {
                "type": "boolean",
                "description": "dummy boolean, set it to true",
-           }
        },
        "required": ["dummy"],
    }

-   async def __call__(self, deps: ToolDependencies, **kwargs) -> Dict[str, Any]:
        """Stop the current emotion."""
        logger.info("Tool call: stop_emotion")
        movement_manager = deps.movement_manager
@@ -442,7 +444,7 @@ class DoNothing(Tool):
        "required": [],
    }

-   async def __call__(self, deps: ToolDependencies, **kwargs) -> Dict[str, Any]:
        """Do nothing - stay still and silent."""
        reason = kwargs.get("reason", "just chilling")
        logger.info("Tool call: do_nothing reason=%s", reason)
@@ -452,12 +454,12 @@
# Registry & specs (dynamic)

# List of available tool classes
-ALL_TOOLS: Dict[str, Tool] = {cls.name: cls() for cls in get_concrete_subclasses(Tool)}
ALL_TOOL_SPECS = [tool.spec() for tool in ALL_TOOLS.values()]


# Dispatcher
-def _safe_load_obj(args_json: str) ->
    try:
        parsed_args = json.loads(args_json or "{}")
        return parsed_args if isinstance(parsed_args, dict) else {}
import asyncio
import inspect
import logging
+from typing import Any, Dict, List, Tuple, Literal
from dataclasses import dataclass

from reachy_mini import ReachyMini
...
EMOTION_AVAILABLE = False


+def get_concrete_subclasses(base: type[Tool]) -> List[type[Tool]]:
    """Recursively find all concrete (non-abstract) subclasses of a base class."""
+   result: List[type[Tool]] = []
    for cls in base.__subclasses__():
        if not inspect.isabstract(cls):
            result.append(cls)
...
    reachy_mini: ReachyMini
    movement_manager: Any  # MovementManager from moves.py
    # Optional deps
+   camera_worker: Any | None = None  # CameraWorker for frame buffering
+   vision_manager: Any | None = None
+   head_wobbler: Any | None = None  # HeadWobbler for audio-reactive motion
    motion_duration_s: float = 1.0


...
        }

    @abc.abstractmethod
+   async def __call__(self, deps: ToolDependencies, **kwargs: Any) -> Dict[str, Any]:
        """Async tool execution entrypoint."""
        raise NotImplementedError

...
    }

    # mapping: direction -> args for create_head_pose
+   DELTAS: Dict[str, Tuple[int, int, int, int, int, int]] = {
        "left": (0, 0, 0, 0, 0, 40),
        "right": (0, 0, 0, 0, 0, -40),
        "up": (0, 0, 0, 0, -30, 0),
...
        "front": (0, 0, 0, 0, 0, 0),
    }

+   async def __call__(self, deps: ToolDependencies, **kwargs: Any) -> Dict[str, Any]:
        """Move head in a given direction."""
+       direction_raw = kwargs.get("direction")
+       if not isinstance(direction_raw, str):
+           return {"error": "direction must be a string"}
+       direction: Direction = direction_raw  # type: ignore[assignment]
        logger.info("Tool call: move_head direction=%s", direction)

        deltas = self.DELTAS.get(direction, self.DELTAS["front"])
...
        "required": ["question"],
    }

+   async def __call__(self, deps: ToolDependencies, **kwargs: Any) -> Dict[str, Any]:
        """Take a picture with the camera and ask a question about it."""
        image_query = (kwargs.get("question") or "").strip()
        if not image_query:
...
        # Use vision manager for processing if available
        if deps.vision_manager is not None:
            vision_result = await asyncio.to_thread(
+               deps.vision_manager.processor.process_image, frame, image_query,
            )
            if isinstance(vision_result, dict) and "error" in vision_result:
                return vision_result
...
            if isinstance(vision_result, str)
            else {"error": "vision returned non-string"}
        )
+       # Return base64 encoded image like main_works.py camera tool
+       import base64

+       import cv2

+       temp_path = "/tmp/camera_frame.jpg"
+       cv2.imwrite(temp_path, frame)
+       with open(temp_path, "rb") as f:
+           b64_encoded = base64.b64encode(f.read()).decode("utf-8")
+       return {"b64_im": b64_encoded}


class HeadTracking(Tool):
...
        "required": ["start"],
    }

+   async def __call__(self, deps: ToolDependencies, **kwargs: Any) -> Dict[str, Any]:
        """Enable or disable head tracking."""
        enable = bool(kwargs.get("start"))

...
        "required": [],
    }

+   async def __call__(self, deps: ToolDependencies, **kwargs: Any) -> Dict[str, Any]:
        """Play a named or random dance move once (or repeat). Non-blocking."""
        if not DANCE_AVAILABLE:
            return {"error": "Dance system not available"}

+       move_name = kwargs.get("move")
        repeat = int(kwargs.get("repeat", 1))

        logger.info("Tool call: dance move=%s repeat=%d", move_name, repeat)
...
            "dummy": {
                "type": "boolean",
                "description": "dummy boolean, set it to true",
+           },
        },
        "required": ["dummy"],
    }

+   async def __call__(self, deps: ToolDependencies, **kwargs: Any) -> Dict[str, Any]:
        """Stop the current dance move."""
        logger.info("Tool call: stop_dance")
        movement_manager = deps.movement_manager
...
        "required": ["emotion"],
    }

+   async def __call__(self, deps: ToolDependencies, **kwargs: Any) -> Dict[str, Any]:
        """Play a pre-recorded emotion."""
        if not EMOTION_AVAILABLE:
            return {"error": "Emotion system not available"}
...

        except Exception as e:
            logger.exception("Failed to play emotion")
+           return {"error": f"Failed to play emotion: {e!s}"}


class StopEmotion(Tool):
...
            "dummy": {
                "type": "boolean",
                "description": "dummy boolean, set it to true",
+           },
        },
        "required": ["dummy"],
    }

+   async def __call__(self, deps: ToolDependencies, **kwargs: Any) -> Dict[str, Any]:
        """Stop the current emotion."""
        logger.info("Tool call: stop_emotion")
        movement_manager = deps.movement_manager
...
        "required": [],
    }

+   async def __call__(self, deps: ToolDependencies, **kwargs: Any) -> Dict[str, Any]:
        """Do nothing - stay still and silent."""
        reason = kwargs.get("reason", "just chilling")
        logger.info("Tool call: do_nothing reason=%s", reason)
...
# Registry & specs (dynamic)

# List of available tool classes
+ALL_TOOLS: Dict[str, Tool] = {cls.name: cls() for cls in get_concrete_subclasses(Tool)}  # type: ignore[type-abstract]
ALL_TOOL_SPECS = [tool.spec() for tool in ALL_TOOLS.values()]


# Dispatcher
+def _safe_load_obj(args_json: str) -> Dict[str, Any]:
    try:
        parsed_args = json.loads(args_json or "{}")
        return parsed_args if isinstance(parsed_args, dict) else {}
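One detail the annotations above lean on: the registry at the bottom of tools.py is rebuilt from every concrete subclass of Tool, so adding a tool is just defining a class. A stripped-down sketch of that mechanism follows, with a hypothetical Wave tool; the spec fields are simplified and are not the project's exact schema.

    import abc
    import inspect
    from typing import Any, Dict, List


    class Tool(abc.ABC):
        name: str = "tool"
        description: str = ""

        def spec(self) -> Dict[str, Any]:
            # Simplified function-call spec; the real project builds a richer schema.
            return {"type": "function", "name": self.name, "description": self.description}

        @abc.abstractmethod
        async def __call__(self, **kwargs: Any) -> Dict[str, Any]:
            raise NotImplementedError


    class Wave(Tool):  # hypothetical example tool
        name = "wave"
        description = "Wave at the user."

        async def __call__(self, **kwargs: Any) -> Dict[str, Any]:
            return {"status": "waved"}


    def get_concrete_subclasses(base: type[Tool]) -> List[type[Tool]]:
        """Recursively collect non-abstract subclasses."""
        result: List[type[Tool]] = []
        for cls in base.__subclasses__():
            if not inspect.isabstract(cls):
                result.append(cls)
            result.extend(get_concrete_subclasses(cls))
        return result


    ALL_TOOLS: Dict[str, Tool] = {cls.name: cls() for cls in get_concrete_subclasses(Tool)}
    ALL_TOOL_SPECS = [tool.spec() for tool in ALL_TOOLS.values()]

    print(sorted(ALL_TOOLS))  # -> ['wave']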
src/reachy_mini_conversation_demo/utils.py
CHANGED
@@ -1,11 +1,13 @@
import logging
import argparse
import warnings

from reachy_mini_conversation_demo.camera_worker import CameraWorker


-def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser("Reachy Mini Conversation Demo")
    parser.add_argument(
@@ -26,7 +28,7 @@ def parse_args():
    return parser.parse_args()


-def handle_vision_stuff(args, current_robot):
    """Initialize camera, head tracker, camera worker, and vision manager.

    By default, vision is handled by gpt-realtime model when camera tool is used.
@@ -44,7 +46,7 @@ def handle_vision_stuff(args, current_robot):

        head_tracker = HeadTracker()
    elif args.head_tracker == "mediapipe":
-       from reachy_mini_toolbox.vision import HeadTracker

        head_tracker = HeadTracker()

@@ -59,17 +61,17 @@ def handle_vision_stuff(args, current_robot):
            vision_manager = initialize_vision_manager(camera_worker)
        except ImportError as e:
            raise ImportError(
-               "To use --local-vision, please install the extra dependencies: pip install '.[local_vision]'"
            ) from e
    else:
        logging.getLogger(__name__).info(
-           "Using gpt-realtime for vision (default). Use --local-vision for local processing."
        )

    return camera_worker, head_tracker, vision_manager


-def setup_logger(debug):
    """Setups the logger."""
    log_level = "DEBUG" if debug else "INFO"
    logging.basicConfig(
import logging
import argparse
import warnings
+from typing import Any, Tuple

+from reachy_mini import ReachyMini
from reachy_mini_conversation_demo.camera_worker import CameraWorker


+def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser("Reachy Mini Conversation Demo")
    parser.add_argument(
...
    return parser.parse_args()


+def handle_vision_stuff(args: argparse.Namespace, current_robot: ReachyMini) -> Tuple[CameraWorker | None, Any, Any]:
    """Initialize camera, head tracker, camera worker, and vision manager.

    By default, vision is handled by gpt-realtime model when camera tool is used.
...

        head_tracker = HeadTracker()
    elif args.head_tracker == "mediapipe":
+       from reachy_mini_toolbox.vision import HeadTracker  # type: ignore[no-redef]

        head_tracker = HeadTracker()

...
            vision_manager = initialize_vision_manager(camera_worker)
        except ImportError as e:
            raise ImportError(
+               "To use --local-vision, please install the extra dependencies: pip install '.[local_vision]'",
            ) from e
    else:
        logging.getLogger(__name__).info(
+           "Using gpt-realtime for vision (default). Use --local-vision for local processing.",
        )

    return camera_worker, head_tracker, vision_manager


+def setup_logger(debug: bool) -> logging.Logger:
    """Setups the logger."""
    log_level = "DEBUG" if debug else "INFO"
    logging.basicConfig(
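The --local-vision branch above shows a convention used throughout the repo: heavy optional dependencies are imported lazily, and a failed import is re-raised with the exact pip command that installs the matching extra. A generic sketch of that pattern follows; the module name and helper are placeholders, not the project's code.

    from types import ModuleType


    def load_local_vision() -> ModuleType:
        """Placeholder helper illustrating the lazy-import-with-hint pattern."""
        try:
            import some_local_vision_lib  # hypothetical module provided by an optional extra
        except ImportError as e:
            # Re-raise with an actionable hint instead of a bare ModuleNotFoundError.
            raise ImportError(
                "To use --local-vision, please install the extra dependencies: pip install '.[local_vision]'"
            ) from e
        return some_local_vision_lib


    try:
        load_local_vision()
    except ImportError as err:
        print(err)  # prints the install hint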
src/reachy_mini_conversation_demo/vision/processors.py
CHANGED
@@ -3,12 +3,13 @@ import time
import base64
import logging
import threading
-from typing import Any, Dict
from dataclasses import dataclass

import cv2
import numpy as np
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from huggingface_hub import snapshot_download

@@ -34,7 +35,7 @@ class VisionConfig:
class VisionProcessor:
    """Handles SmolVLM2 model loading and inference."""

-   def __init__(self, vision_config: VisionConfig = None):
        """Initialize the vision processor."""
        self.vision_config = vision_config or VisionConfig()
        self.model_path = self.vision_config.model_path
@@ -60,7 +61,7 @@
        """Load model and processor onto the selected device."""
        try:
            logger.info(f"Loading SmolVLM2 model on {self.device} (HF_HOME={config.HF_HOME})")
-           self.processor = AutoProcessor.from_pretrained(self.model_path)

            # Select dtype depending on device
            if self.device == "cuda":
@@ -70,16 +71,17 @@
            else:
                dtype = torch.float32

-           model_kwargs = {"dtype": dtype}

            # flash_attention_2 is CUDA-only; skip on MPS/CPU
            if self.device == "cuda":
                model_kwargs["_attn_implementation"] = "flash_attention_2"

            # Load model weights
-           self.model = AutoModelForImageTextToText.from_pretrained(self.model_path, **model_kwargs).to(self.device)

-           self.model
            self._initialized = True
            return True

@@ -89,11 +91,11 @@

    def process_image(
        self,
-       cv2_image: np.
        prompt: str = "Briefly describe what you see in one sentence.",
    ) -> str:
        """Process CV2 image and return description with retry logic."""
-       if not self._initialized:
            return "Vision model not initialized"

        for attempt in range(self.vision_config.max_retries):
@@ -205,16 +207,16 @@
class VisionManager:
    """Manages periodic vision processing and scene understanding."""

-   def __init__(self, camera, vision_config: VisionConfig = None):
        """Initialize vision manager with camera and configuration."""
        self.camera = camera
        self.vision_config = vision_config or VisionConfig()
        self.vision_interval = self.vision_config.vision_interval
        self.processor = VisionProcessor(self.vision_config)

-       self._last_processed_time = 0
        self._stop_event = threading.Event()
-       self._thread:

        # Initialize processor
        if not self.processor.initialize():
@@ -245,7 +247,7 @@
            frame = self.camera.get_latest_frame()
            if frame is not None:
                description = self.processor.process_image(
-                   frame, "Briefly describe what you see in one sentence."
                )

                # Only update if we got a valid response
@@ -274,7 +276,7 @@
    }


-def initialize_vision_manager(camera_worker) ->
    """Initialize vision manager with model download and configuration.

    Args:
@@ -318,7 +320,7 @@ def initialize_vision_manager(camera_worker) -> Optional[VisionManager]:
    # Log device info
    device_info = vision_manager.processor.get_model_info()
    logger.info(
-       f"Vision processing enabled: {device_info.get('model_path')} on {device_info.get('device')}"
    )

    return vision_manager
import base64
import logging
import threading
+from typing import Any, Dict
from dataclasses import dataclass

import cv2
import numpy as np
import torch
+from numpy.typing import NDArray
from transformers import AutoProcessor, AutoModelForImageTextToText
from huggingface_hub import snapshot_download

...
class VisionProcessor:
    """Handles SmolVLM2 model loading and inference."""

+   def __init__(self, vision_config: VisionConfig | None = None):
        """Initialize the vision processor."""
        self.vision_config = vision_config or VisionConfig()
        self.model_path = self.vision_config.model_path
...
        """Load model and processor onto the selected device."""
        try:
            logger.info(f"Loading SmolVLM2 model on {self.device} (HF_HOME={config.HF_HOME})")
+           self.processor = AutoProcessor.from_pretrained(self.model_path)  # type: ignore[no-untyped-call]

            # Select dtype depending on device
            if self.device == "cuda":
...
            else:
                dtype = torch.float32

+           model_kwargs: Dict[str, Any] = {"dtype": dtype}

            # flash_attention_2 is CUDA-only; skip on MPS/CPU
            if self.device == "cuda":
                model_kwargs["_attn_implementation"] = "flash_attention_2"

            # Load model weights
+           self.model = AutoModelForImageTextToText.from_pretrained(self.model_path, **model_kwargs).to(self.device)  # type: ignore[arg-type]

+           if self.model is not None:
+               self.model.eval()
            self._initialized = True
            return True

...

    def process_image(
        self,
+       cv2_image: NDArray[np.uint8],
        prompt: str = "Briefly describe what you see in one sentence.",
    ) -> str:
        """Process CV2 image and return description with retry logic."""
+       if not self._initialized or self.processor is None or self.model is None:
            return "Vision model not initialized"

        for attempt in range(self.vision_config.max_retries):
...
class VisionManager:
    """Manages periodic vision processing and scene understanding."""

+   def __init__(self, camera: Any, vision_config: VisionConfig | None = None):
        """Initialize vision manager with camera and configuration."""
        self.camera = camera
        self.vision_config = vision_config or VisionConfig()
        self.vision_interval = self.vision_config.vision_interval
        self.processor = VisionProcessor(self.vision_config)

+       self._last_processed_time = 0.0
        self._stop_event = threading.Event()
+       self._thread: threading.Thread | None = None

        # Initialize processor
        if not self.processor.initialize():
...
            frame = self.camera.get_latest_frame()
            if frame is not None:
                description = self.processor.process_image(
+                   frame, "Briefly describe what you see in one sentence.",
                )

                # Only update if we got a valid response
...
    }


+def initialize_vision_manager(camera_worker: Any) -> VisionManager | None:
    """Initialize vision manager with model download and configuration.

    Args:
...
    # Log device info
    device_info = vision_manager.processor.get_model_info()
    logger.info(
+       f"Vision processing enabled: {device_info.get('model_path')} on {device_info.get('device')}",
    )

    return vision_manager
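Most of the processor changes pin down types around device and dtype selection: model_kwargs is annotated Dict[str, Any] so the CUDA-only _attn_implementation key can be added conditionally, and flash-attention is skipped off CUDA. Here is a small sketch of that selection logic in isolation; the half-precision dtype in the CUDA branch is an assumption, since that line falls outside the hunks shown above.

    from typing import Any, Dict

    import torch


    def build_model_kwargs(device: str) -> Dict[str, Any]:
        """Mirror the dtype / attention-implementation selection shown in the diff."""
        if device == "cuda":
            dtype = torch.float16  # assumption: the CUDA branch picks a reduced-precision dtype
        else:
            dtype = torch.float32  # MPS and CPU stay in float32

        model_kwargs: Dict[str, Any] = {"dtype": dtype}
        if device == "cuda":
            # flash_attention_2 is CUDA-only; skip on MPS/CPU
            model_kwargs["_attn_implementation"] = "flash_attention_2"
        return model_kwargs


    print(build_model_kwargs("cpu"))   # {'dtype': torch.float32}
    print(build_model_kwargs("cuda"))  # adds the flash_attention_2 entry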
src/reachy_mini_conversation_demo/vision/yolo_head_tracker.py
CHANGED
@@ -1,16 +1,17 @@
from __future__ import annotations
import logging
-from typing import Tuple

import numpy as np


try:
    from supervision import Detections
-   from ultralytics import YOLO
except ImportError as e:
    raise ImportError(
-       "To use YOLO head tracker, please install the extra dependencies: pip install '.[yolo_vision]'"
    ) from e
from huggingface_hub import hf_hub_download

@@ -48,7 +49,7 @@ class HeadTracker:
        logger.error(f"Failed to load YOLO model: {e}")
        raise

-   def _select_best_face(self, detections: Detections) ->
        """Select the best face based on confidence and area (largest face with highest confidence).

        Args:
@@ -61,6 +62,10 @@
        if detections.xyxy.shape[0] == 0:
            return None

        # Filter by confidence threshold
        valid_mask = detections.confidence >= self.confidence_threshold
        if not np.any(valid_mask):
@@ -78,9 +83,9 @@

        # Return index of best face
        best_idx = valid_indices[np.argmax(scores)]
-       return best_idx

-   def _bbox_to_mp_coords(self, bbox: np.
        """Convert bounding box center to MediaPipe-style coordinates [-1, 1].

        Args:
@@ -101,7 +106,7 @@

        return np.array([norm_x, norm_y], dtype=np.float32)

-   def get_head_position(self, img: np.
        """Get head position from face detection.

        Args:
@@ -125,9 +130,10 @@
            return None, None

        bbox = detections.xyxy[face_idx]
-       confidence = detections.confidence[face_idx]

-

        # Get face center in [-1, 1] coordinates
        face_center = self._bbox_to_mp_coords(bbox, w, h)
from __future__ import annotations
import logging
+from typing import Tuple

import numpy as np
+from numpy.typing import NDArray


try:
    from supervision import Detections
+   from ultralytics import YOLO  # type: ignore[attr-defined]
except ImportError as e:
    raise ImportError(
+       "To use YOLO head tracker, please install the extra dependencies: pip install '.[yolo_vision]'",
    ) from e
from huggingface_hub import hf_hub_download

...
        logger.error(f"Failed to load YOLO model: {e}")
        raise

+   def _select_best_face(self, detections: Detections) -> int | None:
        """Select the best face based on confidence and area (largest face with highest confidence).

        Args:
...
        if detections.xyxy.shape[0] == 0:
            return None

+       # Check if confidence is available
+       if detections.confidence is None:
+           return None
+
        # Filter by confidence threshold
        valid_mask = detections.confidence >= self.confidence_threshold
        if not np.any(valid_mask):
...

        # Return index of best face
        best_idx = valid_indices[np.argmax(scores)]
+       return int(best_idx)

+   def _bbox_to_mp_coords(self, bbox: NDArray[np.float32], w: int, h: int) -> NDArray[np.float32]:
        """Convert bounding box center to MediaPipe-style coordinates [-1, 1].

        Args:
...

        return np.array([norm_x, norm_y], dtype=np.float32)

+   def get_head_position(self, img: NDArray[np.uint8]) -> Tuple[NDArray[np.float32] | None, float | None]:
        """Get head position from face detection.

        Args:
...
            return None, None

        bbox = detections.xyxy[face_idx]

+       if detections.confidence is not None:
+           confidence = detections.confidence[face_idx]
+           logger.debug(f"Face detected with confidence: {confidence:.2f}")

        # Get face center in [-1, 1] coordinates
        face_center = self._bbox_to_mp_coords(bbox, w, h)
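For readers new to _bbox_to_mp_coords: the tracker reduces an xyxy box to its center and rescales that center into MediaPipe-style [-1, 1] image coordinates. A standalone version of the arithmetic is sketched below; the axis and sign conventions are the usual ones, and the project's exact convention may differ.

    import numpy as np
    from numpy.typing import NDArray


    def bbox_center_to_unit_coords(bbox: NDArray[np.float32], w: int, h: int) -> NDArray[np.float32]:
        """Map an xyxy box center to [-1, 1] coordinates, with (0, 0) at the image center."""
        x1, y1, x2, y2 = bbox
        cx = (x1 + x2) / 2.0
        cy = (y1 + y2) / 2.0
        norm_x = (cx / w) * 2.0 - 1.0
        norm_y = (cy / h) * 2.0 - 1.0
        return np.array([norm_x, norm_y], dtype=np.float32)


    # A face box centered in a 640x480 frame maps to (0, 0).
    print(bbox_center_to_unit_coords(np.array([300, 220, 340, 260], dtype=np.float32), 640, 480))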
tests/audio/test_head_wobbler.py
CHANGED
@@ -4,7 +4,8 @@ import math
import time
import base64
import threading
-from typing import List, Tuple

import numpy as np

@@ -74,7 +75,7 @@ def test_reset_allows_future_offsets() -> None:
    wobbler.stop()


-def test_reset_during_inflight_chunk_keeps_worker(monkeypatch) -> None:
    """Simulate reset during chunk processing to ensure the worker survives."""
    wobbler, captured = _start_wobbler()
    ready = threading.Event()
import time
import base64
import threading
+from typing import Any, List, Tuple
+from collections.abc import Callable

import numpy as np

...
    wobbler.stop()


+def test_reset_during_inflight_chunk_keeps_worker(monkeypatch: Any) -> None:
    """Simulate reset during chunk processing to ensure the worker survives."""
    wobbler, captured = _start_wobbler()
    ready = threading.Event()
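The test signature above types monkeypatch as Any; a stricter alternative, if wanted later, is pytest's own MonkeyPatch class, which keeps the fixture fully typed without extra dependencies. A tiny sketch follows; the helper under test is made up for illustration.

    import pytest


    def _chunk_duration(sample_rate: int) -> float:
        return 1024 / sample_rate  # toy helper standing in for real audio code


    def test_chunk_duration_patched(monkeypatch: pytest.MonkeyPatch) -> None:
        """monkeypatch is fully typed when annotated with pytest.MonkeyPatch."""
        monkeypatch.setenv("AUDIO_DEBUG", "1")
        assert _chunk_duration(16000) == pytest.approx(0.064)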
uv.lock
CHANGED
The diff for this file is too large to render. See raw diff.