"""Extract Wav2Vec2 hidden states from audio files and save them as safetensors.

Reads a list of audio paths, runs each file through a Wav2Vec2 encoder with
``output_hidden_states=True``, and stores the selected layers' activations
(one ``.st`` file per input, keys ``layer_<i>``).
"""

import argparse
from pathlib import Path

import librosa
import soundfile as sf
import torch
from safetensors.torch import save_file
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import Wav2Vec2Model

# xlsr-53 (large) yields 25 hidden states: the feature-projection output
# plus 24 transformer layers.
NUM_HIDDEN_STATES = 25
DEFAULT_MODEL = "facebook/wav2vec2-large-xlsr-53"


class AudioDataset(Dataset):
    """Dataset yielding ``(waveform, path)`` pairs resampled to a target rate.

    Args:
        file_list: Sequence of ``Path`` objects pointing at audio files.
        target_sr: Sample rate expected by the model (Wav2Vec2 uses 16 kHz).
    """

    def __init__(self, file_list, target_sr=16000):
        self.paths = file_list
        self.target_sr = target_sr

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        wav, sr = sf.read(str(path))
        # FIX: soundfile returns shape (T, C) for multi-channel audio, but the
        # model expects mono — downmix before resampling to avoid shape errors.
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        if sr != self.target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
        wav = torch.tensor(wav).float().unsqueeze(0)  # shape: [1, T]
        return wav, path


def collate_pairs(samples):
    """Transpose a list of ``(wav, path)`` pairs into ``(wavs, paths)``.

    Defined at module level (not as a lambda) so DataLoader workers can
    pickle it under the 'spawn' start method (Windows/macOS).
    """
    return list(zip(*samples))


@torch.no_grad()
def encode_batch(model, batch, device, out_dir, keep_layers):
    """Run one batch through the model and save selected hidden states.

    Args:
        model: A ``Wav2Vec2Model`` in eval mode.
        batch: ``(wavs, paths)`` as produced by :func:`collate_pairs`.
        device: Target ``torch.device`` for inference.
        out_dir: Directory that receives one ``<stem>.st`` file per input.
        keep_layers: Set of hidden-state indices to persist.
    """
    wavs, paths = batch
    for wav, path in zip(wavs, paths):
        wav = wav.to(device)
        outputs = model(wav, output_hidden_states=True)
        hidden_states = outputs.hidden_states  # tuple of tensors: [1, T', D]
        selected = {
            f"layer_{i}": hs.squeeze(0).cpu()
            for i, hs in enumerate(hidden_states)
            if i in keep_layers
        }
        out_path = out_dir / (path.stem + ".st")
        save_file(selected, str(out_path))


def parse_layers(layer_str):
    """Parse a layer spec into a set of hidden-state indices.

    Args:
        layer_str: Either ``"all"`` (case-insensitive) or a comma-separated
            list of non-negative integers, e.g. ``"0,12,24"``.

    Returns:
        Set of integer layer indices.

    Raises:
        ValueError: If the spec contains no valid indices — previously such
            input was silently accepted and produced empty output files.
    """
    if layer_str.strip().lower() == "all":
        return set(range(NUM_HIDDEN_STATES))
    layers = set()
    for token in layer_str.split(","):
        token = token.strip()
        if not token:
            continue
        if not token.isdigit():
            raise ValueError(f"Invalid layer index: {token!r}")
        layers.add(int(token))
    if not layers:
        raise ValueError(f"No layer indices found in {layer_str!r}")
    return layers


def main():
    """CLI entry point: parse arguments, load the model, process all files."""
    parser = argparse.ArgumentParser(description="Infer Wav2Vec2 hidden states.")
    parser.add_argument(
        "file_list", type=Path, help="Text file with paths to audio files"
    )
    parser.add_argument("output_dir", type=Path, help="Directory to save .st files")
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--num_workers", type=int, default=2)
    parser.add_argument(
        "--layers",
        type=str,
        default="all",
        help="Comma-separated layer indices or 'all'",
    )
    # Generalization: the checkpoint used to be hard-coded; the default keeps
    # the original behavior.
    parser.add_argument(
        "--model",
        type=str,
        default=DEFAULT_MODEL,
        help="HuggingFace model name or path",
    )
    args = parser.parse_args()

    try:
        keep_layers = parse_layers(args.layers)
    except ValueError as e:
        parser.error(str(e))

    device = torch.device(args.device)
    model = Wav2Vec2Model.from_pretrained(args.model).to(device).eval()

    with open(args.file_list, "r") as f:
        paths = [Path(line.strip()) for line in f if line.strip()]

    dataset = AudioDataset(paths)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=args.num_workers,
        collate_fn=collate_pairs,  # FIX: picklable, unlike the original lambda
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    for batch in tqdm(dataloader):
        try:
            encode_batch(model, batch, device, args.output_dir, keep_layers)
        except Exception as e:
            # FIX: identify which file(s) failed instead of an anonymous batch.
            failed = ", ".join(str(p) for p in batch[1])
            print(f"❌ Failed on {failed}: {e}")


if __name__ == "__main__":
    main()