AIvry committed
Commit 236f287 · verified · 1 Parent(s): 7a64434

Upload 11 files

Files changed (5)
  1. app.py +4 -6
  2. argshield.py +2 -3
  3. config.py +33 -30
  4. engine.py +239 -194
  5. models.py +323 -359
app.py CHANGED
@@ -88,7 +88,7 @@ def process_audio_files(zip_file, model_name, layer, alpha):
     model_defaults = {
         "wavlm": 24, "wav2vec2": 24, "hubert": 24,
         "wavlm_base": 12, "wav2vec2_base": 12, "hubert_base": 12,
-        "wav2vec2_xlsr": 24, "ast": 12
+        "wav2vec2_xlsr": 24
     }
     layer_final = layer if layer is not None else model_defaults.get(model_name, 12)
 
@@ -184,7 +184,6 @@ def create_interface():
     | `wav2vec2_base` | Wav2Vec2 Base | 12 | Faster, good quality |
     | `hubert_base` | HuBERT Base | 12 | |
     | `wav2vec2_xlsr` | Wav2Vec2 XLSR-53 | 24 | Multilingual |
-    | `ast` | Audio Spectrogram Transformer | 12 | Music |
 
     ## Parameters
 
@@ -233,7 +232,7 @@ def create_interface():
     model_dropdown = gr.Dropdown(
         choices=["raw", "wavlm", "wav2vec2", "hubert",
                  "wavlm_base", "wav2vec2_base", "hubert_base",
-                 "wav2vec2_xlsr", "ast"],
+                 "wav2vec2_xlsr"],
         value="wav2vec2_base",
         label="Select embedding model"
     )
@@ -265,8 +264,7 @@ def create_interface():
         "wav2vec2_xlsr": {"maximum": 24, "value": 24, "interactive": True},
         "wavlm_base": {"maximum": 12, "value": 12, "interactive": True},
         "wav2vec2_base": {"maximum": 12, "value": 12, "interactive": True},
-        "hubert_base": {"maximum": 12, "value": 12, "interactive": True},
-        "ast": {"maximum": 12, "value": 12, "interactive": True}
+        "hubert_base": {"maximum": 12, "value": 12, "interactive": True}
     }
 
     config = model_configs.get(model_name, {"maximum": 12, "value": 12, "interactive": True})
@@ -308,4 +306,4 @@ def create_interface():
 
 if __name__ == "__main__":
     demo = create_interface()
-    demo.launch()
+    demo.launch()
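For context, a minimal sketch of the dropdown-to-slider wiring that the `model_configs` table above drives: picking a model resets the layer slider's bounds. The handler name and the two-model subset here are hypothetical, not the app's actual code.

```python
import gradio as gr

# Hypothetical handler sketch: the real app maps every model to its
# maximum/default layer; only two entries are shown here.
def on_model_change(model_name):
    cfg = {"wav2vec2_xlsr": {"maximum": 24, "value": 24}}.get(
        model_name, {"maximum": 12, "value": 12}
    )
    return gr.update(maximum=cfg["maximum"], value=cfg["value"], interactive=True)

with gr.Blocks() as demo:
    model = gr.Dropdown(["wav2vec2_base", "wav2vec2_xlsr"], value="wav2vec2_base",
                        label="Select embedding model")
    layer = gr.Slider(1, 12, value=12, step=1, label="Layer")
    model.change(on_model_change, inputs=model, outputs=layer)

if __name__ == "__main__":
    demo.launch()
```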
argshield.py CHANGED
@@ -17,7 +17,6 @@ MODEL_DEFAULT_LAYER = {
     "wav2vec2_base": 12,
     "hubert_base": 12,
     "wav2vec2_xlsr": 24,
-    "ast": 12,
 }
 
 def _read_manifest_json(path: Path):
@@ -88,7 +87,7 @@ def _parse_args():
         required=True,
         help=("Embedding model. Choices: "
               "raw, wavlm, wav2vec2, hubert, wavlm_base, wav2vec2_base, "
-              "hubert_base, wav2vec2_xlsr, ast"),
+              "hubert_base, wav2vec2_xlsr"),
     )
     parser.add_argument(
         "--layer",
@@ -141,4 +140,4 @@ def _validate_gpus(max_gpus_opt):
         raise SystemExit("--max-gpus must be an integer >= 0.")
     if mg < 0:
         raise SystemExit("--max-gpus must be >= 0.")
-    return mg
+    return mg
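A minimal sketch of the fallback the MODEL_DEFAULT_LAYER table exists for: an explicit --layer wins, otherwise the per-model default applies. The resolve_layer helper is hypothetical; the actual lookup lives in the pipeline code.

```python
# Hypothetical helper mirroring the table above (trimmed to three entries).
MODEL_DEFAULT_LAYER = {"wav2vec2_base": 12, "hubert_base": 12, "wav2vec2_xlsr": 24}

def resolve_layer(model: str, layer: int | None) -> int:
    # Explicit value wins; otherwise the per-model default, with 12 as fallback.
    return layer if layer is not None else MODEL_DEFAULT_LAYER.get(model, 12)

assert resolve_layer("wav2vec2_xlsr", None) == 24
assert resolve_layer("wav2vec2_base", 6) == 6
```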
config.py CHANGED
@@ -1,30 +1,33 @@
 import os
 import torch
 
 import warnings
 warnings.filterwarnings(
     "ignore",
     category=UserWarning,
     message=r"^expandable_segments not supported on this platform"
 )
 
 SR = 16_000
 RESULTS_ROOT = "results"
 BATCH_SIZE = 2
 ENERGY_WIN_MS = 20
 ENERGY_HOP_MS = 20
 SILENCE_RATIO = 0.1
 EPS = 1e-4
 COV_TOL = 1e-6
 
 DEFAULT_LAYER = 2
 DEFAULT_ADD_CI = True
 DEFAULT_DELTA_CI = 0.05
 DEFAULT_ALPHA = 1.0
 
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True,garbage_collection_threshold:0.6"
 os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
 
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.deterministic = False
 torch.backends.cudnn.enabled = True
+
+if torch.cuda.is_available():
+    torch.cuda.set_per_process_memory_fraction(0.8)
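A minimal sketch (assuming a CUDA device) of what the new cap buys: the process is limited to 80% of the card's memory, so an oversized allocation fails fast with OutOfMemoryError instead of crowding out co-tenant processes.

```python
import torch

if torch.cuda.is_available():
    # Cap this process at 80% of GPU 0, as config.py now does.
    torch.cuda.set_per_process_memory_fraction(0.8, device=0)
    total = torch.cuda.get_device_properties(0).total_memory
    try:
        # ~90% of the card is above the 80% cap, so this should be rejected.
        torch.empty(int(total * 0.9), dtype=torch.uint8, device="cuda:0")
    except torch.cuda.OutOfMemoryError:
        print("allocation above the 80% cap was rejected, as expected")
```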
engine.py CHANGED
@@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 import librosa
 import pandas as pd
+import numpy as np
 from audio import (
     assign_outputs_to_refs_by_corr,
     loudness_normalize,
@@ -73,7 +74,7 @@
 
     if algos is None:
         algos_to_run = sorted(
-            {algo for m in canon_mix for algo in (m.systems or {}).keys()}
+            {algo for algo in canon_mix[0].systems.keys()} if canon_mix and canon_mix[0].systems else []
        )
     else:
         algos_to_run = list(algos)
@@ -132,6 +133,7 @@
     win = int(ENERGY_WIN_MS * SR / 1000)
     hop = int(ENERGY_HOP_MS * SR / 1000)
     voiced_mask_mix = []
+    total_frames_per_mix = []  # Store total frames for each mixture
 
     for i, mix in enumerate(mixture_entries):
         if verbose:
@@ -142,6 +144,7 @@
             refs_for_mix = [all_refs[e["id"]].cuda() for e in mix]
             mask = make_union_voiced_mask(refs_for_mix, win, hop)
             voiced_mask_mix.append(mask.cpu())
+            total_frames_per_mix.append(mask.shape[0])
             # Explicitly delete GPU tensors
             for ref in refs_for_mix:
                 del ref
@@ -150,21 +153,33 @@
             refs_for_mix = [all_refs[e["id"]].cpu() for e in mix]
             mask = make_union_voiced_mask(refs_for_mix, win, hop)
             voiced_mask_mix.append(mask.cpu())
+            total_frames_per_mix.append(mask.shape[0])
 
     ordered_speakers = [e["id"] for e in flat_entries]
 
-    for algo_idx, algo in enumerate(algos_to_run):
-        if verbose:
-            print(f"\nProcessing Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}")
-
-        algo_dir = os.path.join(exp_root, algo)
-        os.makedirs(algo_dir, exist_ok=True)
-
-        all_outs = {}
-        missing = []
-
-        for mix_idx, mix in enumerate(mixture_entries):
-            for e in mix:
+    # Initialize storage for all mixtures and algorithms
+    all_mixture_results = {}  # mixture_id -> {algo -> {model -> data}}
+
+    for mix_idx, (mix_canon, mix_entries) in enumerate(zip(canon_mix, mixture_entries)):
+        mixture_id = mix_canon.mixture_id
+        all_mixture_results[mixture_id] = {}
+
+        # Get total frames for this mixture
+        total_frames = total_frames_per_mix[mix_idx]
+
+        # Get speakers for this mixture
+        mixture_speakers = [e["id"] for e in mix_entries]
+
+        for algo_idx, algo in enumerate(algos_to_run):
+            if verbose:
+                print(f"\nProcessing Mixture {mixture_id}, Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}")
+
+            # Remove the old algo_dir creation here - we don't need these empty folders anymore
+
+            all_outs = {}
+            missing = []
+
+            for e in mix_entries:
                 assigned_path = e.get("outs", {}).get(algo)
                 if assigned_path is None:
                     missing.append((e["mixture"], e["id"]))
@@ -173,49 +188,49 @@
                 wav, _ = librosa.load(str(assigned_path), sr=SR)
                 all_outs[e["id"]] = torch.from_numpy(loudness_normalize(wav))
 
-        if missing:
-            msg = f"[{algo}] missing outputs for {len(missing)} speaker(s)"
-            if on_missing == "error":
-                raise FileNotFoundError(msg)
-            else:
-                if verbose:
-                    warnings.warn(msg + " Skipping those speakers.")
-
-        if not all_outs:
-            if verbose:
-                warnings.warn(f"[{algo}] No outputs provided. Skipping algorithm.")
-            continue
-
-        ps_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-        pm_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-        ps_bias_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-        ps_prob_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-        pm_bias_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-        pm_prob_ts = {m: {s: [] for s in ordered_speakers} for m in models}
-
-        for model_idx, mname in enumerate(models):
-            if verbose:
-                print(f"  Processing Model {model_idx + 1}/{len(models)}: {mname}")
+            if missing:
+                msg = f"[{algo}] missing outputs for {len(missing)} speaker(s) in mixture {mixture_id}"
+                if on_missing == "error":
+                    raise FileNotFoundError(msg)
+                else:
+                    if verbose:
+                        warnings.warn(msg + " Skipping those speakers.")
 
-            for metric_type in ["PS", "PM"]:
-                clear_gpu_memory()
-                gc.collect()
+            if not all_outs:
+                if verbose:
+                    warnings.warn(f"[{algo}] No outputs for mixture {mixture_id}. Skipping.")
+                continue
+
+            # Initialize storage for this algorithm
+            if algo not in all_mixture_results[mixture_id]:
+                all_mixture_results[mixture_id][algo] = {}
+
+            # Initialize frame-wise storage with NaN for all frames
+            ps_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            pm_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            ps_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            ps_prob_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            pm_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+            pm_prob_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
+
+            for model_idx, mname in enumerate(models):
+                if verbose:
+                    print(f"  Processing Model {model_idx + 1}/{len(models)}: {mname}")
 
-                model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
-                get_gpu_memory_info(verbose)
+                for metric_type in ["PS", "PM"]:
+                    clear_gpu_memory()
+                    gc.collect()
 
-                embs_by_mix = {}
-                labels_by_mix = {}
+                    model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
+                    get_gpu_memory_info(verbose)
 
-                for k, mix in enumerate(mixture_entries):
-                    speakers_this_mix = [e for e in mix if e["id"] in all_outs]
+                    # Process only this mixture
+                    speakers_this_mix = [e for e in mix_entries if e["id"] in all_outs]
                     if not speakers_this_mix:
                         continue
 
                     if verbose:
-                        print(
-                            f"Processing mixture {k + 1}/{len(mixture_entries)} for {metric_type}"
-                        )
+                        print(f"    Processing {metric_type} for mixture {mixture_id}")
 
                     all_signals_mix = []
                     all_masks_mix = []
@@ -240,7 +255,7 @@
                         sigs = [all_refs[s].numpy(), all_outs[s].numpy()] + dists
                         lbls = ["ref", "out"] + [f"d{i}" for i in range(len(dists))]
 
-                        masks = [voiced_mask_mix[k]] * len(sigs)
+                        masks = [voiced_mask_mix[mix_idx]] * len(sigs)
                         all_signals_mix.extend(sigs)
                         all_masks_mix.extend(masks)
                         all_labels_mix.extend([f"{s}-{l}" for l in lbls])
@@ -269,12 +284,124 @@
 
                         if embeddings_list:
                             embeddings = torch.cat(embeddings_list, dim=0)
-                            embs_by_mix[k] = embeddings
-                            labels_by_mix[k] = all_labels_mix
+                            E, L, D = embeddings.shape
+
+                            if L == 0:
+                                if verbose:
+                                    print(
+                                        f"    WARNING: mixture {mixture_id} produced 0 frames after masking; skipping.")
+                                continue
+
+                            # Get valid frame indices for this mixture
+                            mask = voiced_mask_mix[mix_idx]
+                            valid_frame_indices = torch.where(mask)[0].tolist()
+
+                            if verbose:
+                                print(f"    Computing {metric_type} scores for {mname}...")
+
+                            # Process frames with their stored embeddings and labels
+                            with ThreadPoolExecutor(
+                                max_workers=min(2, ngpu if ngpu > 0 else 1)
+                            ) as executor:
+
+                                def process_frame(f, frame_idx, embeddings_mix, labels_mix):
+                                    try:
+                                        frame_emb = embeddings_mix[:, f, :].detach().cpu().numpy()
+
+                                        if add_ci:
+                                            coords_d, coords_c, eigvals, k_sub_gauss = (
+                                                gpu_distributor.execute_on_gpu(
+                                                    diffusion_map_torch,
+                                                    frame_emb,
+                                                    labels_mix,
+                                                    alpha=alpha,
+                                                    eig_solver="full",
+                                                    return_eigs=True,
+                                                    return_complement=True,
+                                                    return_cval=add_ci,
+                                                )
+                                            )
+                                        else:
+                                            coords_d = gpu_distributor.execute_on_gpu(
+                                                diffusion_map_torch,
+                                                frame_emb,
+                                                labels_mix,
+                                                alpha=alpha,
+                                                eig_solver="full",
+                                                return_eigs=False,
+                                                return_complement=False,
+                                                return_cval=False,
+                                            )
+                                            coords_c = None
+                                            eigvals = None
+                                            k_sub_gauss = 1
+
+                                        if metric_type == "PS":
+                                            score = compute_ps(
+                                                coords_d, labels_mix, max_gpus
+                                            )
+                                            bias = prob = None
+                                            if add_ci:
+                                                bias, prob = ps_ci_components_full(
+                                                    coords_d,
+                                                    coords_c,
+                                                    eigvals,
+                                                    labels_mix,
+                                                    delta=DEFAULT_DELTA_CI,
+                                                )
+                                            return frame_idx, "PS", score, bias, prob
+                                        else:
+                                            score = compute_pm(
+                                                coords_d, labels_mix, "gamma", max_gpus
+                                            )
+                                            bias = prob = None
+                                            if add_ci:
+                                                bias, prob = pm_ci_components_full(
+                                                    coords_d,
+                                                    coords_c,
+                                                    eigvals,
+                                                    labels_mix,
+                                                    delta=DEFAULT_DELTA_CI,
+                                                    K=k_sub_gauss,
+                                                )
+                                            return frame_idx, "PM", score, bias, prob
+
+                                    except Exception as ex:
+                                        if verbose:
+                                            print(f"    ERROR frame {frame_idx}: {ex}")
+                                        return None
+
+                                futures = [
+                                    executor.submit(process_frame, f, valid_frame_indices[f], embeddings,
+                                                    all_labels_mix)
+                                    for f in range(L)
+                                ]
+
+                                for fut in futures:
+                                    result = fut.result()
+                                    if result is None:
+                                        continue
+
+                                    frame_idx, metric, score, bias, prob = result
+
+                                    if metric == "PS":
+                                        for sp in score:
+                                            if sp in mixture_speakers:
+                                                ps_frames[mname][sp][frame_idx] = score[sp]
+                                                if add_ci and bias is not None:
+                                                    ps_bias_frames[mname][sp][frame_idx] = bias[sp]
+                                                    ps_prob_frames[mname][sp][frame_idx] = prob[sp]
+                                    else:
+                                        for sp in score:
+                                            if sp in mixture_speakers:
+                                                pm_frames[mname][sp][frame_idx] = score[sp]
+                                                if add_ci and bias is not None:
+                                                    pm_bias_frames[mname][sp][frame_idx] = bias[sp]
+                                                    pm_prob_frames[mname][sp][frame_idx] = prob[sp]
 
                     except Exception as ex:
                         if verbose:
-                            print(f"    ERROR processing mixture {k + 1}: {ex}")
+                            print(f"    ERROR processing mixture {mixture_id}: {ex}")
                         continue
                     finally:
                         # Always clean up after processing a mixture
@@ -284,155 +411,73 @@
                         clear_gpu_memory()
                         gc.collect()
 
-                if verbose:
-                    print(f"  Computing {metric_type} scores for {mname}...")
-
-                # Process mixtures with their stored embeddings and labels
-                with ThreadPoolExecutor(
-                    max_workers=min(2, ngpu if ngpu > 0 else 1)
-                ) as executor:
-                    for k in range(len(mixture_entries)):
-                        if k not in embs_by_mix:
-                            continue
-
-                        E, L, D = embs_by_mix[k].shape
-                        if L == 0:
-                            if verbose:
-                                print(f"    WARNING: mixture {k + 1} produced 0 frames after masking; skipping.")
-                            continue
-
-                        # Get the labels for this mixture
-                        labels_for_mix = labels_by_mix[k]
-
-                        def process_frame(f, embeddings_mix, labels_mix):
-                            try:
-                                frame_emb = embeddings_mix[:, f, :].detach().cpu().numpy()
-
-                                if add_ci:
-                                    coords_d, coords_c, eigvals, k_sub_gauss = (
-                                        gpu_distributor.execute_on_gpu(
-                                            diffusion_map_torch,
-                                            frame_emb,
-                                            labels_mix,
-                                            alpha=alpha,
-                                            eig_solver="full",
-                                            return_eigs=True,
-                                            return_complement=True,
-                                            return_cval=add_ci,
-                                        )
-                                    )
-                                else:
-                                    coords_d = gpu_distributor.execute_on_gpu(
-                                        diffusion_map_torch,
-                                        frame_emb,
-                                        labels_mix,
-                                        alpha=alpha,
-                                        eig_solver="full",
-                                        return_eigs=False,
-                                        return_complement=False,
-                                        return_cval=False,
-                                    )
-                                    coords_c = None
-                                    eigvals = None
-                                    k_sub_gauss = 1
-
-                                if metric_type == "PS":
-                                    score = compute_ps(
-                                        coords_d, labels_mix, max_gpus
-                                    )
-                                    bias = prob = None
-                                    if add_ci:
-                                        bias, prob = ps_ci_components_full(
-                                            coords_d,
-                                            coords_c,
-                                            eigvals,
-                                            labels_mix,
-                                            delta=DEFAULT_DELTA_CI,
-                                        )
-                                    return f, "PS", score, bias, prob
-                                else:
-                                    score = compute_pm(
-                                        coords_d, labels_mix, "gamma", max_gpus
-                                    )
-                                    bias = prob = None
-                                    if add_ci:
-                                        bias, prob = pm_ci_components_full(
-                                            coords_d,
-                                            coords_c,
-                                            eigvals,
-                                            labels_mix,
-                                            delta=DEFAULT_DELTA_CI,
-                                            K=k_sub_gauss,
-                                        )
-                                    return f, "PM", score, bias, prob
-
-                            except Exception as ex:
-                                if verbose:
-                                    print(f"    ERROR frame {f + 1}: {ex}")
-                                return None
-
-                        futures = [
-                            executor.submit(process_frame, f, embs_by_mix[k], labels_for_mix)
-                            for f in range(L)
-                        ]
-                        for fut in futures:
-                            result = fut.result()
-                            if result is None:
-                                continue
-
-                            f, metric, score, bias, prob = result
-
-                            if metric == "PS":
-                                for sp in score:
-                                    ps_ts[mname][sp].append(score[sp])
-                                    if add_ci and bias is not None:
-                                        ps_bias_ts[mname][sp].append(bias[sp])
-                                        ps_prob_ts[mname][sp].append(prob[sp])
-                            else:
-                                for sp in score:
-                                    pm_ts[mname][sp].append(score[sp])
-                                    if add_ci and bias is not None:
-                                        pm_bias_ts[mname][sp].append(bias[sp])
-                                        pm_prob_ts[mname][sp].append(prob[sp])
-
-                # Clean up after processing all mixtures for this metric
-                del embs_by_mix, labels_by_mix
+                    del model_wrapper
                     clear_gpu_memory()
                     gc.collect()
 
-            del model_wrapper
-            clear_gpu_memory()
-            gc.collect()
-
+                    # Store results for this mixture and algorithm
+                    all_mixture_results[mixture_id][algo][mname] = {
+                        'ps_frames': ps_frames[mname],
+                        'pm_frames': pm_frames[mname],
+                        'ps_bias_frames': ps_bias_frames[mname] if add_ci else None,
+                        'ps_prob_frames': ps_prob_frames[mname] if add_ci else None,
+                        'pm_bias_frames': pm_bias_frames[mname] if add_ci else None,
+                        'pm_prob_frames': pm_prob_frames[mname] if add_ci else None,
+                        'total_frames': total_frames
+                    }
+
+        # Save results for this mixture after processing all algorithms
         if verbose:
-            print(f"  Saving results for {algo}...")
-
-        for m in models:
-
-            def _pad(vec, n):
-                return vec + [np.nan] * (n - len(vec))
-
-            max_len = 0
-            for s in ordered_speakers:
-                max_len = max(max_len, len(ps_ts[m][s]), len(pm_ts[m][s]))
-
-            pd.DataFrame(
-                {s: _pad(ps_ts[m][s], max_len) for s in ordered_speakers}
-            ).to_csv(os.path.join(algo_dir, f"ps_scores_{m}.csv"), index=False)
-
-            pd.DataFrame(
-                {s: _pad(pm_ts[m][s], max_len) for s in ordered_speakers}
-            ).to_csv(os.path.join(algo_dir, f"pm_scores_{m}.csv"), index=False)
-
-            if add_ci:
-                ci_cols = {}
-                for s in ordered_speakers:
-                    ci_cols[f"{s}_ps_bias"] = _pad(ps_bias_ts[m][s], max_len)
-                    ci_cols[f"{s}_ps_prob"] = _pad(ps_prob_ts[m][s], max_len)
-                    ci_cols[f"{s}_pm_bias"] = _pad(pm_bias_ts[m][s], max_len)
-                    ci_cols[f"{s}_pm_prob"] = _pad(pm_prob_ts[m][s], max_len)
-                pd.DataFrame(ci_cols).to_csv(
-                    os.path.join(algo_dir, f"ci_{m}.csv"), index=False
+            print(f"  Saving results for mixture {mixture_id}...")
+
+        # Create timestamps in milliseconds - using lowercase hop
+        timestamps_ms = [i * hop * 1000 / SR for i in range(total_frames)]
+
+        for model in models:
+            # Prepare PS data
+            ps_data = {'timestamp_ms': timestamps_ms}
+            pm_data = {'timestamp_ms': timestamps_ms}
+            ci_data = {'timestamp_ms': timestamps_ms} if add_ci else None
+
+            # Combine data from all algorithms for this mixture
+            for algo in algos_to_run:
+                if algo not in all_mixture_results[mixture_id]:
+                    continue
+                if model not in all_mixture_results[mixture_id][algo]:
+                    continue
+
+                model_data = all_mixture_results[mixture_id][algo][model]
+
+                # Add PS data
+                for speaker in mixture_speakers:
+                    col_name = f"{algo}_{speaker}"
+                    ps_data[col_name] = model_data['ps_frames'][speaker]
+                    pm_data[col_name] = model_data['pm_frames'][speaker]
+
+                    if add_ci and ci_data is not None:
+                        ci_data[f"{algo}_{speaker}_ps_bias"] = model_data['ps_bias_frames'][speaker]
+                        ci_data[f"{algo}_{speaker}_ps_prob"] = model_data['ps_prob_frames'][speaker]
+                        ci_data[f"{algo}_{speaker}_pm_bias"] = model_data['pm_bias_frames'][speaker]
+                        ci_data[f"{algo}_{speaker}_pm_prob"] = model_data['pm_prob_frames'][speaker]
+
+            # Save CSV files for this mixture
+            mixture_dir = os.path.join(exp_root, mixture_id)
+            os.makedirs(mixture_dir, exist_ok=True)
+
+            pd.DataFrame(ps_data).to_csv(
+                os.path.join(mixture_dir, f"ps_scores_{model}.csv"),
+                index=False
+            )
+
+            pd.DataFrame(pm_data).to_csv(
+                os.path.join(mixture_dir, f"pm_scores_{model}.csv"),
+                index=False
+            )
+
+            if add_ci and ci_data is not None:
+                pd.DataFrame(ci_data).to_csv(
+                    os.path.join(mixture_dir, f"ci_{model}.csv"),
+                    index=False
                 )
 
         del all_outs
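A minimal sketch of consuming the new per-mixture output layout (the experiment and mixture names below are illustrative): each mixture folder holds one CSV per embedding model, rows are energy frames keyed by timestamp_ms, columns follow the "algo_speaker" pattern, and frames outside the union voiced mask stay NaN.

```python
import pandas as pd

# Hypothetical paths; the real layout is results/<exp_root>/<mixture_id>/.
df = pd.read_csv("results/my_experiment/mix_001/ps_scores_wav2vec2_base.csv")
score_cols = [c for c in df.columns if c != "timestamp_ms"]  # "<algo>_<speaker>"
voiced = df.dropna(subset=score_cols, how="all")             # drop unvoiced frames
print(voiced[score_cols].mean())                             # mean PS per column
```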
models.py CHANGED
@@ -1,360 +1,324 @@
 import queue
 import threading
 import gc
 
 import torch
 import torch.nn.functional as F
 from transformers import (
     HubertModel,
     Wav2Vec2FeatureExtractor,
     Wav2Vec2Model,
     WavLMModel,
-    ASTModel,
-    AutoFeatureExtractor,
 )
 
 from config import BATCH_SIZE, ENERGY_HOP_MS, ENERGY_WIN_MS, SR
 from utils import get_gpu_count
 
 
 class BalancedDualGPUModel:
 
     def __init__(self, model_name, layer, max_gpus=None):
         self.layer = layer
         self.models = []
         self.extractors = []
         self.devices = []
         ngpu = get_gpu_count(max_gpus)
 
-        # This class should only be used when GPUs are available
-        if ngpu == 0:
-            raise RuntimeError("BalancedDualGPUModel requires at least 1 GPU")
-
         for gpu_id in range(min(ngpu, 2)):
             device = f"cuda:{gpu_id}"
             self.devices.append(device)
             ckpt, cls, _ = get_model_config(layer)[model_name]
-            if cls is ASTModel:
-                extractor = AutoFeatureExtractor.from_pretrained(ckpt)
-            else:
-                extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
+            extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
 
-            attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
-            # Use float32 for better compatibility
+            attn_impl = "eager" if cls is WavLMModel else "sdpa"
             model = cls.from_pretrained(
                 ckpt,
                 output_hidden_states=True,
                 use_safetensors=True,
-                torch_dtype=torch.float32,  # Changed from float16
+                torch_dtype=torch.float16,
                 low_cpu_mem_usage=True,
                 attn_implementation=attn_impl
             )
             model.eval()
             model = model.to(device)
 
             for param in model.parameters():
                 param.requires_grad = False
 
             self.extractors.append(extractor)
             self.models.append(model)
 
         self.gpu_queues = [queue.Queue() for _ in range(len(self.devices))]
         self.result_queue = queue.Queue()
         self.workers = []
         for i in range(len(self.devices)):
             worker = threading.Thread(target=self._gpu_worker, args=(i,))
             worker.daemon = True
             worker.start()
             self.workers.append(worker)
 
     def _gpu_worker(self, gpu_id):
         device = self.devices[gpu_id]
         model = self.models[gpu_id]
         extractor = self.extractors[gpu_id]
         while True:
             task = self.gpu_queues[gpu_id].get()
             if task is None:
                 break
             signals, masks, use_mlm, task_id = task
             try:
                 inputs = extractor(
                     signals, sampling_rate=SR, return_tensors="pt", padding=True
                 )
                 input_values = inputs.input_values.to(device, non_blocking=True)
 
                 torch.cuda.empty_cache()
 
                 orig_mode = model.training
                 model.train() if use_mlm else model.eval()
                 with torch.no_grad():
-                    # Only use autocast on actual GPUs with float16 support
-                    if torch.cuda.is_available() and device.startswith('cuda'):
-                        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                            hs = model(
-                                input_values, output_hidden_states=True
-                            ).hidden_states[self.layer]
-                    else:
+                    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                         hs = model(
                             input_values, output_hidden_states=True
                         ).hidden_states[self.layer]
                 model.train(orig_mode)
 
                 B, T, D = hs.shape
                 keep = []
                 for b in range(B):
                     mask_b = masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
                     mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
                     keep.append(hs[b][mask_t].cpu())
 
                 # Aggressive cleanup
                 del hs, input_values, inputs
                 torch.cuda.empty_cache()
 
                 if keep:
                     L_max = max(x.shape[0] for x in keep)
                     keep_padded = [
                         F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in keep
                     ]
                     result = torch.stack(keep_padded, dim=0)
                 else:
                     result = torch.empty(0, 0, 0)
                 self.result_queue.put((task_id, result))
             except Exception as e:
                 self.result_queue.put((task_id, e))
             finally:
                 # Always clear cache after processing
                 torch.cuda.empty_cache()
 
     def process_batch(self, signals, masks, use_mlm=False):
         if not signals:
             return torch.empty(0, 0, 0)
         batch_size = len(signals)
         split = (batch_size + len(self.devices) - 1) // len(self.devices)
         results = {}
         task_id = 0
         for i in range(0, batch_size, split):
             end = min(i + split, batch_size)
             gpu_id = (i // split) % len(self.devices)
             self.gpu_queues[gpu_id].put(
                 (signals[i:end], masks[i:end], use_mlm, task_id)
             )
             task_id += 1
         for _ in range(task_id):
             tid, result = self.result_queue.get()
             if isinstance(result, Exception):
                 raise result
             results[tid] = result
         parts = [results[i] for i in range(task_id) if results[i].numel() > 0]
         return torch.cat(parts, dim=0) if parts else torch.empty(0, 0, 0)
 
     def cleanup(self):
         """Explicit cleanup method"""
         for q in self.gpu_queues:
             q.put(None)
         for w in self.workers:
             w.join(timeout=5.0)
         for model in self.models:
             del model
         for extractor in self.extractors:
             del extractor
         self.models.clear()
         self.extractors.clear()
         torch.cuda.empty_cache()
         gc.collect()
 
     def __del__(self):
         self.cleanup()
 
 
 # NO CACHE - we need to clean up models properly between runs
 def get_model_config(layer):
     return {
         "raw": (None, None, None),
         "wavlm": ("microsoft/wavlm-large", WavLMModel, layer),
         "wav2vec2": ("facebook/wav2vec2-large-lv60", Wav2Vec2Model, layer),
         "hubert": ("facebook/hubert-large-ll60k", HubertModel, layer),
         "wavlm_base": ("microsoft/wavlm-base", WavLMModel, layer),
         "wav2vec2_base": ("facebook/wav2vec2-base", Wav2Vec2Model, layer),
         "hubert_base": ("facebook/hubert-base-ls960", HubertModel, layer),
         "wav2vec2_xlsr": ("facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model, layer),
-        "ast": ("MIT/ast-finetuned-audioset-10-10-0.4593", ASTModel, layer),
     }
 
 
 # Store loaded models globally to properly manage them
 _loaded_models = {}
 
 
 def load_model(name, layer, max_gpus=None):
     global _loaded_models
 
     # Clean up any previously loaded models first
     if _loaded_models:
         for key, model_data in _loaded_models.items():
             if isinstance(model_data, tuple) and len(model_data) == 2:
                 if isinstance(model_data[0], BalancedDualGPUModel):
                     model_data[0].cleanup()
                 elif isinstance(model_data[0], tuple):
                     # Single GPU model
                     _, model = model_data[0]
                     del model
         _loaded_models.clear()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
         gc.collect()
 
     if name.lower() in {"raw", "waveform"}:
         return "raw", layer
 
     ngpu = get_gpu_count(max_gpus)
-
-    # Only use BalancedDualGPUModel if we have multiple GPUs
     if ngpu > 1:
         model = BalancedDualGPUModel(name, layer, max_gpus)
         _loaded_models[name] = (model, layer)
         return model, layer
     else:
         ckpt, cls, layer_eff = get_model_config(layer)[name]
-        if cls is ASTModel:
-            extractor = AutoFeatureExtractor.from_pretrained(ckpt)
-        else:
-            extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
+        extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
 
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
-
-        # CRITICAL FIX: Always use float32 for CPU compatibility
+        attn_impl = "eager" if cls is WavLMModel else "sdpa"
         model = cls.from_pretrained(
             ckpt,
             output_hidden_states=True,
             use_safetensors=True,
-            torch_dtype=torch.float32,  # Changed from float16 to float32
+            torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
             attn_implementation=attn_impl
         )
         model.eval()
         model = model.to(device)
 
         for param in model.parameters():
             param.requires_grad = False
 
         model_tuple = ((extractor, model), layer_eff)
         _loaded_models[name] = model_tuple
         return (extractor, model), layer_eff
 
 
 def cleanup_all_models():
     """Call this at the end of each experiment to ensure complete cleanup"""
     global _loaded_models
     if _loaded_models:
         for key, model_data in _loaded_models.items():
             if isinstance(model_data, tuple) and len(model_data) == 2:
                 if isinstance(model_data[0], BalancedDualGPUModel):
                     model_data[0].cleanup()
                 elif isinstance(model_data[0], tuple):
                     # Single GPU model
                     _, model = model_data[0]
                     del model
         _loaded_models.clear()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
         gc.collect()
 
 
 def embed_batch_raw(signals, masks_audio):
     win = int(ENERGY_WIN_MS * SR / 1000)
     hop = int(ENERGY_HOP_MS * SR / 1000)
     reps, L_max = [], 0
     for sig_np, mask_np in zip(signals, masks_audio):
         x = torch.as_tensor(sig_np[:-1], dtype=torch.float32)
         frames = x.unfold(0, win, hop)
         mask = torch.as_tensor(mask_np[: len(frames)], dtype=torch.bool)
         keep = frames[mask] if mask.any() else frames[:1]
         reps.append(keep)
         L_max = max(L_max, keep.size(0))
     reps = [F.pad(r, (0, 0, 0, L_max - r.size(0))) for r in reps]
     return torch.stack(reps, dim=0)
 
 
 def embed_batch_single_gpu(
     signals, masks_audio, extractor, model, layer, use_mlm=False
 ):
     if not signals:
         return torch.empty(0, 0, 0)
     device = next(model.parameters()).device
-    is_cuda = device.type == 'cuda'
 
     max_batch = 2
     all_keeps = []
 
     for i in range(0, len(signals), max_batch):
         batch_signals = signals[i:i + max_batch]
         batch_masks = masks_audio[i:i + max_batch]
 
         inputs = extractor(batch_signals, sampling_rate=SR, return_tensors="pt", padding=True)
-        input_values = inputs.input_values.to(device, non_blocking=is_cuda)
+        input_values = inputs.input_values.to(device, non_blocking=True)
 
         orig_mode = model.training
         model.train() if use_mlm else model.eval()
 
         with torch.no_grad():
-            # CRITICAL FIX: Don't use autocast on CPU
-            if is_cuda:
-                # On GPU, we can use autocast with float16 for speed
-                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                    hs = model(input_values, output_hidden_states=True).hidden_states[layer]
-            else:
-                # On CPU, just run the model directly without autocast
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                 hs = model(input_values, output_hidden_states=True).hidden_states[layer]
-
         model.train(orig_mode)
 
         B, T, D = hs.shape
         for b in range(B):
             mask_b = batch_masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
             mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
             all_keeps.append(hs[b][mask_t].cpu())
 
         # Aggressive cleanup
         del hs, input_values, inputs
-        if is_cuda:
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
     if all_keeps:
         L_max = max(x.shape[0] for x in all_keeps)
         keep_padded = [F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in all_keeps]
         result = torch.stack(keep_padded, dim=0)
         # Clean up intermediate lists
         del all_keeps, keep_padded
         return result
     else:
         return torch.empty(0, 0, 0)
 
 
 def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
     if model_wrapper == "raw":
         return embed_batch_raw(signals, masks_audio)
     if isinstance(model_wrapper, BalancedDualGPUModel):
         all_embeddings = []
         batch_size = min(BATCH_SIZE, 2)
         for i in range(0, len(signals), batch_size):
             batch_emb = model_wrapper.process_batch(
                 signals[i: i + batch_size], masks_audio[i: i + batch_size], use_mlm
             )
             if batch_emb.numel() > 0:
                 all_embeddings.append(batch_emb)
             # Clear cache after each batch
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
 
         if all_embeddings:
             result = torch.cat(all_embeddings, dim=0)
             del all_embeddings
             return result
         else:
             return torch.empty(0, 0, 0)
     else:
         extractor, model = model_wrapper
         return embed_batch_single_gpu(
             signals, masks_audio, extractor, model, layer, use_mlm=use_mlm
         )
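A minimal sketch of the fp16-plus-autocast inference pattern the workers above rely on, assuming a CUDA device (the checkpoint name matches get_model_config; the dummy waveform is illustrative):

```python
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2Model.from_pretrained(
    "facebook/wav2vec2-base",
    output_hidden_states=True,
    torch_dtype=torch.float16,   # fp16 weights, as in the reverted loader
).eval().to("cuda:0")

wav = torch.randn(16_000).numpy()  # one second of dummy audio at 16 kHz
inputs = extractor(wav, sampling_rate=16_000, return_tensors="pt")
with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    hidden = model(inputs.input_values.to("cuda:0"),
                   output_hidden_states=True).hidden_states
print(len(hidden), hidden[12].shape)  # 13 states: embeddings + 12 layers
```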