"""Encode audio referenced by a HF dataset with WavVAE and save the latents.

For each row, loads the wav file named in ``path_column``, resamples it to the
model's sampling rate, encodes it with WavVAE, and stores the latent in a new
``audio_z`` column. The augmented dataset is saved next to the input as
``<dataset>_with_latents``.
"""

import argparse
from pathlib import Path

import librosa
import soundfile as sf
import torch
from datasets import load_dataset, load_from_disk

from zcodec.models import WavVAE


def load_and_resample(path, target_sr):
    """Load an audio file as a mono float tensor of shape [1, T] at target_sr.

    Args:
        path: path-like pointing at an audio file readable by soundfile.
        target_sr: sampling rate expected by the model; resamples if needed.

    Returns:
        torch.FloatTensor of shape [1, T].
    """
    wav, sr = sf.read(str(path))
    if wav.ndim > 1:
        # soundfile returns (T, C) for multichannel files; downmix to mono so
        # the [1, T] contract below actually holds (the original code would
        # have resampled along the wrong axis and returned [1, T, C]).
        wav = wav.mean(axis=1)
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    return torch.tensor(wav).unsqueeze(0).float()  # shape: [1, T]


def main():
    parser = argparse.ArgumentParser(
        description="Encode HF dataset audio with WavVAE using map() (non-batched)."
    )
    parser.add_argument("dataset", type=str, help="Path or HF hub ID of dataset")
    parser.add_argument("path_column", type=str, help="Column name with wav file paths")
    parser.add_argument(
        "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
    )
    parser.add_argument(
        "--split", type=str, default=None, help="Dataset split (if loading from hub)"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument(
        "--num_proc", type=int, default=1, help="Number of processes for map()"
    )
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load model once in the parent process.
    # NOTE(review): with --num_proc > 1, datasets.map forks workers; a
    # CUDA-resident model generally cannot be shared across forked processes.
    # Confirm num_proc > 1 is only used with --device cpu.
    wavvae = WavVAE.from_pretrained_local(args.checkpoint).to(device).eval()
    target_sr = wavvae.sampling_rate

    # Load dataset: local on-disk dataset takes precedence over a hub ID.
    if Path(args.dataset).exists():
        ds = load_from_disk(args.dataset)
    else:
        ds = load_dataset(args.dataset, split=args.split or "train")

    # Keep only clips longer than 1 second (assumes a "duration" column in
    # seconds — verify against the dataset schema).
    ds = ds.filter(lambda x: x > 1.0, input_columns="duration")

    # Mapping function (non-batched): one example in, one example out.
    @torch.no_grad()
    def encode_example(example):
        wav = load_and_resample(example[args.path_column], target_sr).to(device)
        latent = wavvae.encode(wav).cpu().numpy()
        example["audio_z"] = latent
        return example

    # Apply map without batching.
    ds = ds.map(
        encode_example,
        num_proc=args.num_proc,
    )

    # Save dataset with the new column.
    # Bug fix: the original `Path(args.dataset) + "_with_latents"` raises
    # TypeError (Path does not support + with str); args.dataset is a str,
    # so plain string concatenation gives the intended sibling directory.
    ds.save_to_disk(args.dataset + "_with_latents")


if __name__ == "__main__":
    main()