刘鑫 committed on
Commit 2dd4a32 · 1 Parent(s): cddb0f1

set zero gpu inference

Files changed (1):
  1. app.py (+67 -15)
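
The change below splits the synthesis entry point in two: a CPU-side wrapper (generate_tts_audio) that reads the prompt audio with soundfile, and a @spaces.GPU-decorated worker (generate_tts_audio_gpu) that receives the audio as an in-memory array and writes it to a temporary wav before calling VoxCPM. A minimal sketch of that ZeroGPU split, using only the spaces and soundfile calls that appear in the diff; the function names here are hypothetical, not from the repo:

```python
# Sketch only: keep file I/O on the CPU side, pass plain arrays into the GPU worker.
import numpy as np
import soundfile as sf
import spaces


@spaces.GPU(duration=120)
def gpu_worker(audio: np.ndarray, sr: int) -> np.ndarray:
    # Placeholder for model inference inside the ZeroGPU context.
    return audio


def cpu_entrypoint(wav_path: str) -> np.ndarray:
    # Read the reference audio before entering the GPU context, then hand off arrays.
    audio, sr = sf.read(wav_path, dtype="float32")
    return gpu_worker(audio, sr)
```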
app.py CHANGED
@@ -5,6 +5,8 @@ import gradio as gr
 import spaces
 from typing import Optional, Tuple
 from pathlib import Path
+import tempfile
+import soundfile as sf
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 if os.environ.get("HF_REPO_ID", "").strip() == "":
@@ -66,7 +68,7 @@ def get_voxcpm_model():
     print("Loading VoxCPM model...")
     model_dir = _resolve_model_dir()
     print(f"Using model dir: {model_dir}")
-    _voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
+    _voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=False)
     print("VoxCPM model loaded.")
     return _voxcpm_model
 
@@ -83,9 +85,9 @@ def prompt_wav_recognition(prompt_wav: Optional[str]) -> str:
 
 
 @spaces.GPU(duration=120)
-def generate_tts_audio(
+def generate_tts_audio_gpu(
     text_input: str,
-    prompt_wav_path_input: Optional[str] = None,
+    prompt_wav_data: Optional[Tuple[np.ndarray, int]] = None,
     prompt_text_input: Optional[str] = None,
     cfg_value_input: float = 2.0,
     inference_timesteps_input: int = 10,
@@ -93,8 +95,8 @@ def generate_tts_audio(
     denoise: bool = True,
 ) -> Tuple[int, np.ndarray]:
     """
-    Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
-    Returns (sample_rate, waveform_numpy)
+    GPU function: Generate speech from text using VoxCPM.
+    prompt_wav_data is an (audio_array, sample_rate) tuple.
     """
     voxcpm_model = get_voxcpm_model()
 
@@ -102,20 +104,70 @@ def generate_tts_audio(
     if len(text) == 0:
         raise ValueError("Please input text to synthesize.")
 
-    prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
     prompt_text = prompt_text_input if prompt_text_input else None
+    prompt_wav_path = None
+
+    # If prompt audio data is provided, write it to a temp file for voxcpm
+    if prompt_wav_data is not None:
+        audio_array, sr = prompt_wav_data
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+            sf.write(f.name, audio_array, sr)
+            prompt_wav_path = f.name
+
+    try:
+        print(f"Generating audio for text: '{text[:60]}...'")
+        wav = voxcpm_model.generate(
+            text=text,
+            prompt_text=prompt_text,
+            prompt_wav_path=prompt_wav_path,
+            cfg_value=float(cfg_value_input),
+            inference_timesteps=int(inference_timesteps_input),
+            normalize=do_normalize,
+            denoise=denoise,
+        )
+        return (voxcpm_model.tts_model.sample_rate, wav)
+    finally:
+        # Clean up the temp file
+        if prompt_wav_path and os.path.exists(prompt_wav_path):
+            try:
+                os.unlink(prompt_wav_path)
+            except Exception:
+                pass
+
 
-    print(f"Generating audio for text: '{text[:60]}...'")
-    wav = voxcpm_model.generate(
-        text=text,
-        prompt_text=prompt_text,
-        prompt_wav_path=prompt_wav_path,
-        cfg_value=float(cfg_value_input),
-        inference_timesteps=int(inference_timesteps_input),
-        normalize=do_normalize,
+def generate_tts_audio(
+    text_input: str,
+    prompt_wav_path_input: Optional[str] = None,
+    prompt_text_input: Optional[str] = None,
+    cfg_value_input: float = 2.0,
+    inference_timesteps_input: int = 10,
+    do_normalize: bool = True,
+    denoise: bool = True,
+) -> Tuple[int, np.ndarray]:
+    """
+    Wrapper: read the prompt audio on the CPU, then call the GPU function.
+    """
+    prompt_wav_data = None
+
+    # Read the audio file before entering the GPU context
+    if prompt_wav_path_input and os.path.exists(prompt_wav_path_input):
+        try:
+            audio_array, sr = sf.read(prompt_wav_path_input, dtype='float32')
+            prompt_wav_data = (audio_array, sr)
+            print(f"Loaded prompt audio: {audio_array.shape}, sr={sr}")
+        except Exception as e:
+            print(f"Warning: Failed to load prompt audio: {e}")
+            prompt_wav_data = None
+
+    return generate_tts_audio_gpu(
+        text_input=text_input,
+        prompt_wav_data=prompt_wav_data,
+        prompt_text_input=prompt_text_input,
+        cfg_value_input=cfg_value_input,
+        inference_timesteps_input=inference_timesteps_input,
+        do_normalize=do_normalize,
         denoise=denoise,
     )
-    return (voxcpm_model.tts_model.sample_rate, wav)
 
 
 # ---------- UI Builders ----------
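
For reference, one way to exercise the refactored wrapper outside the Gradio UI; this is a sketch under the assumption that app.py imports cleanly as a module and the model weights resolve locally, and it is not part of this commit:

```python
# Hypothetical smoke test for the new CPU wrapper defined in the diff above.
import soundfile as sf
from app import generate_tts_audio  # assumes app.py is importable as a module

sample_rate, wav = generate_tts_audio(
    text_input="Hello from VoxCPM.",
    prompt_wav_path_input=None,   # or a path to a reference wav for voice cloning
    prompt_text_input=None,
    cfg_value_input=2.0,
    inference_timesteps_input=10,
    do_normalize=True,
    denoise=True,
)
sf.write("output.wav", wav, sample_rate)  # wav is a numpy waveform at sample_rate
```

Outside a ZeroGPU Space the spaces.GPU decorator is designed to have no effect, so the same call path should be testable locally on CPU or a regular GPU.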