#!/usr/bin/env python3
"""Compare WavVAE checkpoints by reconstructing a set of audio files.

For every model in a JSON manifest and every audio file in a text manifest,
the audio is encoded and decoded by the model and written to
``<out>/recon/<model>/<file>.wav``. A mono 16-bit copy of each input is kept
in ``<out>/original/``, and a static ``<out>/index.html`` lays everything out
side by side for listening.
"""
import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import quote

import torch
import torchaudio
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample

from zcodec.models import WavVAE


# -------------------------
# Data structures
# -------------------------
@dataclass
class WavVaeSpec:
    """One entry of the models manifest."""

    name: str  # display name; its sanitized form is the recon sub-directory
    wavvae_dir: str  # directory passed to WavVAE.from_pretrained_local


# -------------------------
# Utilities
# -------------------------
def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
    """Return the JSON object stored at ``path``.

    Returns None if the file is missing, unreadable or not valid JSON, or if
    its top-level value is not an object.
    """
    if not path.is_file():
        return None
    try:
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return None
    return data if isinstance(data, dict) else None


def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    """Return the first usable config found in ``checkpoint_dir``.

    JSON candidates are parsed; a YAML config is not parsed and is only
    reported by path (``{"_config_file": ...}``). Returns ``{}`` if nothing
    usable is found.
    """
    cand = [
        Path(checkpoint_dir) / "config.json",
        Path(checkpoint_dir) / "model_config.json",
        Path(checkpoint_dir) / "config.yaml",  # shown as path only
    ]
    for p in cand:
        if p.exists():
            if p.suffix == ".json":
                j = load_json_if_exists(p)
                if j is not None:
                    return j
            else:
                return {"_config_file": str(p)}
    return {}


def sanitize_name(s: str) -> str:
    """Make ``s`` safe to use as a single file or directory name.

    Every character that is not alphanumeric or one of ``-_.`` becomes ``_``.
    """
    cleaned = "".join(c if c.isalnum() or c in "-_." else "_" for c in s)
    # "", "." and ".." would alias the current/parent directory when used as
    # a path component, so prefix them to keep outputs inside their folder.
    if not cleaned.strip("."):
        cleaned = "_" + cleaned
    return cleaned


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    """Down-mix a ``(C, T)`` waveform to mono and resample it to ``target_sr``.

    Returns:
        The ``(1, T')`` float32 waveform and its sample rate (``target_sr``).

    Raises:
        ValueError: if ``wav`` is not two-dimensional.
    """
    if wav.ndim != 2:
        raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)  # average channels -> (1, T)
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int):
    """Write ``wav`` to ``path`` as 16-bit signed PCM, creating parent dirs.

    ``wav`` may be ``(T,)`` or ``(C, T)``. Samples are clipped to [-1, 1]
    before quantisation, and the tensor is moved to CPU.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


def read_audio_manifest(txt_path: str) -> List[Path]:
    """Read one audio path per line, skipping blank lines and ``#`` comments."""
    lines = Path(txt_path).read_text(encoding="utf-8").splitlines()
    return [
        Path(line.strip())
        for line in lines
        if line.strip() and not line.strip().startswith("#")
    ]


def html_escape(s: str) -> str:
    """Escape ``s`` for HTML text content and quoted attribute values."""
    return (
        s.replace("&", "&amp;")  # must be first so later entities aren't re-escaped
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )


def make_html(
    output_dir: Path,
    audio_files: List[Path],
    specs: List[WavVaeSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
) -> str:
    """Write ``output_dir/index.html`` comparing originals with reconstructions.

    Args:
        output_dir: directory holding ``original/`` and ``recon/``. Audio is
            referenced with URLs relative to it.
        audio_files: prepared originals (files in ``output_dir/original``).
        specs: models, in column order.
        sr_by_model: model name -> sample rate shown on the model's card.
        wavvae_cfg: model name -> config dict shown (truncated) on its card.

    Returns:
        The path of the written page, as a string.
    """

    def player(src_rel: str) -> str:
        # Percent-encode the relative URL (spaces, '#', '?', non-ASCII in file
        # names), then escape it for the double-quoted attribute.
        src = html_escape(quote(src_rel))
        return f'<audio controls preload="none" src="{src}"></audio>'

    # cards
    cards = []
    for s in specs:
        cfg = wavvae_cfg.get(s.name, {})
        cfg_text = json.dumps(cfg if cfg else {"_": "no JSON config found"}, indent=2)
        if len(cfg_text) > 1200:
            # Keep cards compact, and mark the cut so it isn't mistaken for the
            # complete config.
            cfg_text = cfg_text[:1200] + "\n..."
        sr_text = sr_by_model.get(s.name, "N/A")
        cards.append(
            '<div class="model-card">\n'
            f"  <h3>{html_escape(s.name)}</h3>\n"
            f"  <div>Sample rate: {sr_text} Hz</div>\n"
            "  <details>\n"
            "    <summary>WavVAE config</summary>\n"
            f"    <pre>{html_escape(cfg_text)}</pre>\n"
            "  </details>\n"
            "</div>\n"
        )

    css = """
    body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
    .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
    .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
    .model-card pre { white-space: pre-wrap; font-size: 12px; }
    table { border-collapse: collapse; width: 100%; }
    th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
    th { background: #fafafa; position: sticky; top: 0; }
    audio { width: 260px; }
    """

    header_cells = "".join(f"<th>{html_escape(s.name)}</th>" for s in specs)
    th = f"<tr><th>Input</th><th>Original</th>{header_cells}</tr>"

    rows = []
    for af in audio_files:
        base = af.stem
        orig_rel = f"original/{af.name}"
        cells = [
            f"<td>{html_escape(base)}</td>",
            f"<td>{player(orig_rel)}</td>",
        ]
        for s in specs:
            # Must match the path the reconstruction step writes:
            # recon/<sanitized model name>/<sanitized stem>.wav
            rec_rel = f"recon/{sanitize_name(s.name)}/{sanitize_name(base)}.wav"
            cells.append(f"<td>{player(rec_rel)}</td>")
        rows.append("<tr>" + "".join(cells) + "</tr>")

    cards_html = "".join(cards)
    rows_html = "\n".join(rows)
    page = f"""<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>WavVAE Comparison</title>
<style>{css}</style>
</head>
<body>
<h1>WavVAE Comparison</h1>
<div class="cards">
{cards_html}</div>
<table>
<thead>{th}</thead>
<tbody>
{rows_html}
</tbody>
</table>
</body>
</html>
"""
    out = output_dir / "index.html"
    out.write_text(page, encoding="utf-8")
    return str(out)


# -------------------------
# Core
# -------------------------
@torch.inference_mode()
def reconstruct_wavvae(
    wav_mono: torch.Tensor, wavvae: WavVAE, device: str
) -> torch.Tensor:
    """Encode then decode ``wav_mono`` with ``wavvae``.

    ``wav_mono`` is expected to be ``(1, T)`` at the model's sample rate. The
    decoder output is assumed to be ``(1, 1, T')``; its two leading dims are
    squeezed. The result stays on ``device``.
    """
    x = wav_mono.to(device)  # (1,T)
    z = wavvae.encode(x)
    wav_hat = wavvae.decode(z)  # (1,1,T)
    return wav_hat.squeeze(0).squeeze(0).detach()


def parse_models_manifest(path: str) -> List[WavVaeSpec]:
    """Parse the models manifest.

    The manifest is a JSON list of ``{"name": "...", "wavvae": "/path/to/WavVAE_dir"}``.

    Raises:
        ValueError: if the JSON is invalid, is not a list, an entry lacks
            ``name``/``wavvae``, or two names map to the same output directory.
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError(f"{path}: expected a JSON list of model entries")
    specs: List[WavVaeSpec] = []
    seen: Dict[str, str] = {}
    for i, it in enumerate(raw):
        try:
            spec = WavVaeSpec(name=str(it["name"]), wavvae_dir=str(it["wavvae"]))
        except (KeyError, TypeError) as err:
            raise ValueError(
                f"{path}: entry {i} must be an object with 'name' and 'wavvae'"
            ) from err
        # Outputs go to recon/<sanitized name>/; colliding names would
        # overwrite each other's reconstructions.
        key = sanitize_name(spec.name)
        if key in seen:
            raise ValueError(
                f"{path}: model names {seen[key]!r} and {spec.name!r} "
                f"map to the same output directory {key!r}"
            )
        seen[key] = spec.name
        specs.append(spec)
    return specs


def _resolve_device(requested: str) -> str:
    """Return ``requested``, falling back to CPU when CUDA is requested but unavailable."""
    if requested.startswith("cuda") and not torch.cuda.is_available():
        print(
            f"[Warn] {requested} requested but CUDA is unavailable; using cpu.",
            file=sys.stderr,
        )
        return "cpu"
    return requested


def _load_models(
    specs: List[WavVaeSpec], device: str
) -> Tuple[Dict[str, WavVAE], Dict[str, int], Dict[str, Dict[str, Any]]]:
    """Load every model onto ``device``; return (models, sample rates, configs) keyed by name."""
    wavvae_by_name: Dict[str, WavVAE] = {}
    sr_by_model: Dict[str, int] = {}
    wavvae_cfg: Dict[str, Dict[str, Any]] = {}
    for s in specs:
        print(f"[Load] {s.name}")
        w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device)
        if isinstance(w, torch.nn.Module):
            w.eval()  # inference behaviour for dropout / normalisation layers
        wavvae_by_name[s.name] = w
        # 24000 is used when the model exposes no `sampling_rate` attribute.
        sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000))
        wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)
    return wavvae_by_name, sr_by_model, wavvae_cfg


def _prepare_originals(
    audio_paths: List[Path], orig_dir: Path, force: bool
) -> List[Path]:
    """Write mono WAV copies (native sample rate) of the inputs into ``orig_dir``.

    File names are sanitized, and inputs with clashing stems get a ``-2``,
    ``-3``, ... suffix so they don't overwrite each other. Missing or
    undecodable inputs are reported and skipped.

    Returns:
        The prepared files, in manifest order.
    """
    prepared: List[Path] = []
    used_stems = set()
    for src in audio_paths:
        if not src.is_file():
            print(f"[Skip missing] {src}", file=sys.stderr)
            continue
        stem = sanitize_name(src.stem)
        unique, k = stem, 2
        while unique in used_stems:
            unique, k = f"{stem}-{k}", k + 1
        out_orig = orig_dir / f"{unique}.wav"
        if force or not out_orig.exists():
            try:
                wav, sr = ta_load(str(src))
            except RuntimeError as err:
                print(f"[Skip unreadable] {src}: {err}", file=sys.stderr)
                continue
            wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
            save_wav(out_orig, wav_mono, sr)
        used_stems.add(unique)
        prepared.append(out_orig)
    return prepared


def _reconstruct_all(
    originals: List[Path],
    specs: List[WavVaeSpec],
    wavvae_by_name: Dict[str, WavVAE],
    sr_by_model: Dict[str, int],
    recon_dir: Path,
    device: str,
    force: bool,
) -> None:
    """Reconstruct every original with every model.

    Output goes to ``recon_dir/<model>/<stem>.wav``, with both parts
    sanitized. Existing outputs are skipped unless ``force`` is set.
    """
    for out_orig in originals:
        wav0: Optional[torch.Tensor] = None
        sr0 = 0
        for s in specs:
            target_sr = sr_by_model[s.name]
            out_path = (
                recon_dir / sanitize_name(s.name) / f"{sanitize_name(out_orig.stem)}.wav"
            )
            if not force and out_path.exists():
                continue
            if wav0 is None:  # decode lazily: only if some model needs it
                wav0, sr0 = ta_load(str(out_orig))
            wav_in, _ = ensure_mono_and_resample(wav0, sr0, target_sr)
            print(f"[Reconstruct] {s.name} ← {out_orig.name}")
            wav_hat = reconstruct_wavvae(wav_in, wavvae_by_name[s.name], device)
            save_wav(out_path, wav_hat, target_sr)


def main():
    """Command-line entry point: load the models, reconstruct the audio, write the HTML page."""
    parser = argparse.ArgumentParser(
        description="Compare WavVAE checkpoints and generate a static HTML page."
    )
    parser.add_argument("--models", required=True, help="JSON manifest of WavVAE models.")
    parser.add_argument(
        "--audio_manifest", required=True, help="TXT file: one audio path per line."
    )
    parser.add_argument("--out", default="compare_wavvae_out")
    parser.add_argument(
        "--device",
        default="cuda",
        help="Torch device (e.g. cuda, cuda:1, cpu); CUDA falls back to cpu if unavailable.",
    )
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    device = _resolve_device(args.device)

    out_dir = Path(args.out)
    orig_dir = out_dir / "original"
    recon_dir = out_dir / "recon"
    orig_dir.mkdir(parents=True, exist_ok=True)
    recon_dir.mkdir(parents=True, exist_ok=True)

    try:
        specs = parse_models_manifest(args.models)
    except ValueError as err:
        sys.exit(f"Invalid models manifest: {err}")
    if not specs:
        print("No models.", file=sys.stderr)
        sys.exit(1)

    wavvae_by_name, sr_by_model, wavvae_cfg = _load_models(specs, device)

    # Normalize originals to WAV + mono (browser-friendly); keep native sr.
    audio_paths = read_audio_manifest(args.audio_manifest)
    originals = _prepare_originals(audio_paths, orig_dir, args.force)
    if not originals:
        print("[Warn] No usable audio files; the page will be empty.", file=sys.stderr)

    _reconstruct_all(
        originals, specs, wavvae_by_name, sr_by_model, recon_dir, device, args.force
    )

    html_path = make_html(out_dir, originals, specs, sr_by_model, wavvae_cfg)
    print(f"Done. Open: {html_path}")


if __name__ == "__main__":
    main()