#!/usr/bin/env python3
import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torchaudio
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample

from zcodec.models import WavVAE
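# Example invocation (a sketch; the script filename and all paths are illustrative,
# only the flags come from the argument parser defined in main() below):
#
#   python compare_wavvae.py \
#       --models models.json \
#       --audio_manifest audio_list.txt \
#       --out compare_wavvae_out \
#       --device cuda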

# -------------------------
# Data structures
# -------------------------
@dataclass
class WavVaeSpec:
    name: str
    wavvae_dir: str

# -------------------------
# Utilities
# -------------------------
def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
    if path.is_file():
        try:
            with path.open("r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            return None
    return None

def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    cand = [
        Path(checkpoint_dir) / "config.json",
        Path(checkpoint_dir) / "model_config.json",
        Path(checkpoint_dir) / "config.yaml",  # YAML is not parsed; only its path is recorded
    ]
    for p in cand:
        if p.exists():
            if p.suffix == ".json":
                j = load_json_if_exists(p)
                if j is not None:
                    return j
            else:
                return {"_config_file": str(p)}
    return {}

def sanitize_name(s: str) -> str:
    return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    if wav.ndim != 2:
        raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int):
    path.parent.mkdir(parents=True, exist_ok=True)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


def read_audio_manifest(txt_path: str) -> List[Path]:
    lines = Path(txt_path).read_text(encoding="utf-8").splitlines()
    files = [
        Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#")
    ]
    return files
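
# Illustrative --audio_manifest contents (paths are placeholders): one audio file per
# line; blank lines and lines starting with '#' are skipped, as implemented above.
#
#   # evaluation clips
#   /data/speech/clip_001.wav
#   /data/speech/clip_002.flac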

def html_escape(s: str) -> str:
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )

def make_html(
    output_dir: Path,
    audio_files: List[Path],
    specs: List[WavVaeSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
) -> str:
    def player(src_rel: str) -> str:
        return f'<audio controls preload="none" src="{html_escape(src_rel)}"></audio>'

    # cards
    cards = []
    for s in specs:
        cfg = wavvae_cfg.get(s.name, {})
        cfg_short = json.dumps(cfg if cfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        card = f"""
        <div class="model-card">
          <h3>{html_escape(s.name)}</h3>
          <p><b>Sample rate</b>: {sr_by_model.get(s.name, "N/A")} Hz</p>
          <details><summary>WavVAE config</summary><pre>{html_escape(cfg_short)}</pre></details>
        </div>
        """
        cards.append(card)

    css = """
    body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
    .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
    .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
    table { border-collapse: collapse; width: 100%; }
    th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
    th { background: #fafafa; position: sticky; top: 0; }
    audio { width: 260px; }
    """
    th = "<th>Input</th><th>Original</th>" + "".join(
        f"<th>{html_escape(s.name)}</th>" for s in specs
    )
    rows = []
    for af in audio_files:
        base = af.stem
        # player() HTML-escapes its src argument, so build plain relative paths here
        orig_rel = f"original/{af.name}"
        tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
        for s in specs:
            # mirror the on-disk name written in main(), which sanitizes the stem
            rec_rel = f"recon/{s.name}/{sanitize_name(base)}.wav"
            tds.append(f"<td>{player(rec_rel)}</td>")
        rows.append("<tr>" + "".join(tds) + "</tr>")
    html = f"""<!doctype html>
<html>
<head><meta charset="utf-8"/><title>WavVAE Comparison</title><style>{css}</style></head>
<body>
<h1>WavVAE Comparison</h1>
<div class="cards">{"".join(cards)}</div>
<table>
<thead><tr>{th}</tr></thead>
<tbody>{"".join(rows)}</tbody>
</table>
</body>
</html>
"""
    out = output_dir / "index.html"
    out.write_text(html, encoding="utf-8")
    return str(out)
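
# For reference, the layout produced under --out (names follow the code above;
# <out>, <stem>, and <model_name> are placeholders):
#
#   <out>/index.html
#   <out>/original/<stem>.wav
#   <out>/recon/<model_name>/<sanitized_stem>.wav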

# -------------------------
# Core
# -------------------------
@torch.no_grad()  # inference only; no autograd graph needed
def reconstruct_wavvae(
    wav_mono: torch.Tensor, wavvae: WavVAE, device: str
) -> torch.Tensor:
    x = wav_mono.to(device)  # (1, T)
    z = wavvae.encode(x)
    wav_hat = wavvae.decode(z)  # (1, 1, T)
    return wav_hat.squeeze(0).squeeze(0).detach()

def parse_models_manifest(path: str) -> List[WavVaeSpec]:
    """
    JSON list of:
      {"name": "...", "wavvae": "/path/to/WavVAE_dir"}
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    specs = []
    for it in raw:
        specs.append(WavVaeSpec(name=it["name"], wavvae_dir=it["wavvae"]))
    return specs
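
# Illustrative --models manifest (model names and checkpoint paths are placeholders;
# only the "name" and "wavvae" keys are read by parse_models_manifest):
#
#   [
#     {"name": "wavvae_base",  "wavvae": "/checkpoints/wavvae_base"},
#     {"name": "wavvae_large", "wavvae": "/checkpoints/wavvae_large"}
#   ]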

def main():
    ap = argparse.ArgumentParser(
        description="Compare WavVAE checkpoints and generate a static HTML page."
    )
    ap.add_argument("--models", required=True, help="JSON manifest of WavVAE models.")
    ap.add_argument(
        "--audio_manifest", required=True, help="TXT file: one audio path per line."
    )
    ap.add_argument("--out", default="compare_wavvae_out")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--force", action="store_true")
    args = ap.parse_args()

    device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
    out_dir = Path(args.out)
    (out_dir / "original").mkdir(parents=True, exist_ok=True)
    recon_dir = out_dir / "recon"
    recon_dir.mkdir(parents=True, exist_ok=True)

    specs = parse_models_manifest(args.models)
    if not specs:
        print("No models.", file=sys.stderr)
        sys.exit(1)

    # load models
    wavvae_by_name: Dict[str, WavVAE] = {}
    sr_by_model: Dict[str, int] = {}
    wavvae_cfg: Dict[str, Dict[str, Any]] = {}
    for s in specs:
        print(f"[Load] {s.name}")
        w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device).eval()
        wavvae_by_name[s.name] = w
        sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000))  # fall back to 24 kHz
        wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)

    audio_paths = read_audio_manifest(args.audio_manifest)

    # normalize originals to mono wav (browser-friendly); keep the native sample rate
    # for the "Original" column
    actual_audio = []
    for audio_path in audio_paths:
        if not audio_path.exists():
            print(f"[Skip missing] {audio_path}", file=sys.stderr)
            continue
        wav, sr = ta_load(str(audio_path))
        wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)  # target_sr == sr: mono + float32 only
        out_orig = out_dir / "original" / (audio_path.stem + ".wav")
        if args.force or not out_orig.exists():
            save_wav(out_orig, wav_mono, sr)
        actual_audio.append(out_orig)

    # recon per model
    for out_orig in actual_audio:
        wav0, sr0 = ta_load(str(out_orig))
        if wav0.size(0) > 1:
            wav0 = wav0.mean(dim=0, keepdim=True)
        for s in specs:
            target_sr = sr_by_model[s.name]
            wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0
            out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav"
            if args.force or not out_path.exists():
                print(f"[Reconstruct] {s.name} ← {out_orig.name}")
                wav_hat = reconstruct_wavvae(wav_in, wavvae_by_name[s.name], device)
                save_wav(out_path, wav_hat, target_sr)

    html_path = make_html(out_dir, actual_audio, specs, sr_by_model, wavvae_cfg)
    print(f"Done. Open: {html_path}")


if __name__ == "__main__":
    main()