File size: 2,170 Bytes
56cfa73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
from pathlib import Path

import librosa
import soundfile as sf
import torch

from datasets import load_dataset, load_from_disk
from zcodec.models import WavVAE


def load_and_resample(path, target_sr):
    """Load an audio file from *path* and resample it to ``target_sr``.

    Parameters
    ----------
    path : str | Path
        Path to an audio file readable by ``soundfile``.
    target_sr : int
        Desired sampling rate in Hz.

    Returns
    -------
    torch.Tensor
        Float waveform tensor of shape ``[1, T]`` (mono, channel-first).
    """
    wav, sr = sf.read(str(path))
    # sf.read returns shape (T, C) for multi-channel files. Downmix to mono
    # first, otherwise librosa.resample operates on the wrong axis and the
    # unsqueeze below would produce [1, T, C] instead of the intended [1, T].
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    return torch.tensor(wav).unsqueeze(0).float()  # shape: [1, T]


def main():
    """Encode every audio file in a HF dataset with WavVAE and save the result.

    Reads wav paths from ``path_column``, resamples each file to the model's
    sampling rate, encodes it to a latent stored in a new ``audio_z`` column,
    and writes the augmented dataset to ``<dataset>_with_latents``.
    """
    parser = argparse.ArgumentParser(
        description="Encode HF dataset audio with WavVAE using map() (non-batched)."
    )
    parser.add_argument("dataset", type=str, help="Path or HF hub ID of dataset")
    parser.add_argument("path_column", type=str, help="Column name with wav file paths")
    parser.add_argument(
        "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
    )
    parser.add_argument(
        "--split", type=str, default=None, help="Dataset split (if loading from hub)"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument(
        "--num_proc", type=int, default=1, help="Number of processes for map()"
    )
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load model once in the parent process and move it to the target device.
    # NOTE(review): with --num_proc > 1, datasets.map forks workers that must
    # pickle this closure; a CUDA-resident model may not survive that — confirm
    # before using multiprocessing with --device cuda.
    wavvae = WavVAE.from_pretrained_local(args.checkpoint).to(device).eval()
    target_sr = wavvae.sampling_rate

    # Local paths are loaded from disk; anything else is treated as a hub ID.
    if Path(args.dataset).exists():
        ds = load_from_disk(args.dataset)
    else:
        ds = load_dataset(args.dataset, split=args.split or "train")

    # Drop clips of 1 second or less (assumes a "duration" column in seconds).
    ds = ds.filter(lambda x: x > 1.0, input_columns="duration")

    # Mapping function (non-batched): one example in, one example out.
    @torch.no_grad()
    def encode_example(example):
        wav = load_and_resample(example[args.path_column], target_sr).to(device)
        latent = wavvae.encode(wav).cpu().numpy()
        example["audio_z"] = latent
        return example

    # Apply map without batching.
    ds = ds.map(
        encode_example,
        num_proc=args.num_proc,
    )

    # Save dataset with the new column. The original used
    # `Path(args.dataset) + "_with_latents"`, which raises TypeError because
    # Path does not support `+`; build the suffixed path as a string instead.
    ds.save_to_disk(f"{args.dataset}_with_latents")


if __name__ == "__main__":
    main()