Spaces:
Paused
Paused
| """ | |
| Calculate Frechet Audio Distance between two audio directories. | |
| Frechet distance implementation adapted from: https://github.com/mseitzer/pytorch-fid | |
| VGGish adapted from: https://github.com/harritaylor/torchvggish | |
| """ | |
| import os | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from scipy import linalg | |
| from tqdm import tqdm | |
| import soundfile as sf | |
| import resampy | |
| from multiprocessing.dummy import Pool as ThreadPool | |
SAMPLE_RATE = 16000  # Sampling rate (Hz) of every clip returned by load_audio_task.


def load_audio_task(fname):
    """
    Load one audio file as mono float samples at SAMPLE_RATE.
    The file is read as 16-bit integers, divided by 32768, averaged over
    channels when it has more than one, and resampled when its rate differs
    from SAMPLE_RATE.
    If the file cannot be read, the error is printed and ten seconds of
    silence are returned instead of raising.
    Params:
    -- fname : Path of the audio file to read.
    Returns:
    --   : Tuple (wav_data, SAMPLE_RATE); wav_data is a float numpy array.
    """
    try:
        wav_data, sr = sf.read(fname, dtype="int16")
    except Exception as e:
        print(e)
        # The fallback buffer has to be int16 like a real read: a float buffer
        # would fail the dtype check below and turn the intended silence
        # fallback into an AssertionError.
        # NOTE(review): callers that compute statistics over many files will
        # now see silence for unreadable files rather than an exception.
        wav_data = np.zeros(10 * SAMPLE_RATE, dtype=np.int16)
        sr = SAMPLE_RATE
    assert wav_data.dtype == np.int16, "Bad sample type: %r" % wav_data.dtype
    wav_data = wav_data / 32768.0  # Convert to [-1.0, +1.0]
    # Convert to mono
    if len(wav_data.shape) > 1:
        wav_data = np.mean(wav_data, axis=1)
    if sr != SAMPLE_RATE:
        if SAMPLE_RATE == 16000 and sr == 32000:
            # Keep every second sample; no low-pass filter is applied first.
            wav_data = wav_data[::2]
        else:
            wav_data = resampy.resample(wav_data, sr, SAMPLE_RATE)
    return wav_data, SAMPLE_RATE
class FrechetAudioDistance:
    """
    Frechet Audio Distance between two sets of audio clips.
    Clips are embedded with a VGGish model loaded through torch.hub; the
    distance is computed between the mean/covariance of the two embedding sets.
    """

    def __init__(
        self, use_pca=False, use_activation=False, verbose=False, audio_load_worker=8
    ):
        """
        Params:
        -- use_pca           : If False, the model's `postprocess` flag is turned off.
        -- use_activation    : If False, the last child module of the model's
                               `embeddings` stack is removed.
        -- verbose           : Print progress messages and show progress bars.
        -- audio_load_worker : Number of threads used by the parallel audio loader.
        """
        self.__get_model(use_pca=use_pca, use_activation=use_activation)
        self.verbose = verbose
        self.audio_load_worker = audio_load_worker

    def __get_model(self, use_pca=False, use_activation=False):
        """
        Load the VGGish model from torch.hub into `self.model` and set eval mode.
        Params:
        -- use_pca        : If False, `self.model.postprocess` is set to False.
        -- use_activation : If False, `self.model.embeddings` is rebuilt without
                            its last child module.
        """
        self.model = torch.hub.load("harritaylor/torchvggish", "vggish")
        if not use_pca:
            self.model.postprocess = False
        if not use_activation:
            self.model.embeddings = nn.Sequential(
                *list(self.model.embeddings.children())[:-1]
            )
        self.model.eval()

    def _embed(self, audio, sr):
        """Run the model on one clip and return its embedding as a numpy array."""
        embd = self.model.forward(audio, sr)
        # .cpu() is a no-op for a tensor that already lives on the CPU, so no
        # device comparison is needed (and "cuda:0" style devices work too).
        return embd.detach().cpu().numpy()

    def get_embeddings(self, x, sr=16000, limit_num=None):
        """
        Get embeddings using VGGish model.
        Params:
        -- x         : Either
            (i) a string which is the directory of a set of audio files, or
            (ii) a list whose items are (audio, sr) pairs or bare audio arrays
        -- sr        : Sampling rate used for bare audio arrays in a list.
                       Default value is 16000.
        -- limit_num : If not None, at most this many .wav files of a directory
                       are embedded (taken in sorted filename order).
        Returns:
        --   : The embeddings concatenated along axis 0, or an empty array
               when nothing could be embedded.
        Raises:
        -- AttributeError if x is neither a list nor a string.
        """
        embd_lst = []
        if isinstance(x, list):
            try:
                for item in tqdm(x, disable=(not self.verbose)):
                    if isinstance(item, (tuple, list)) and len(item) == 2:
                        audio, item_sr = item
                    else:
                        audio, item_sr = item, sr
                    embd_lst.append(self._embed(audio, item_sr))
            except Exception as e:
                # Best effort: keep whatever was embedded before the failure.
                print(
                    "[Frechet Audio Distance] get_embeddings throw an exception: {}".format(
                        str(e)
                    )
                )
        elif isinstance(x, str):
            if self.verbose:
                print("Calculating the embedding of the audio files inside %s" % x)
            try:
                # Filter before limiting so non-wav entries do not use up the
                # budget; sort so the selected subset does not depend on the
                # arbitrary order of os.listdir.
                wav_names = sorted(f for f in os.listdir(x) if f.endswith(".wav"))
                if limit_num is not None:
                    wav_names = wav_names[: max(limit_num, 0)]
                for fname in tqdm(wav_names, disable=(not self.verbose)):
                    try:
                        audio, file_sr = load_audio_task(os.path.join(x, fname))
                        embd_lst.append(self._embed(audio, file_sr))
                    except Exception as e:
                        # One bad file is reported and skipped.
                        print(e, fname)
                        continue
            except Exception as e:
                print(
                    "[Frechet Audio Distance] get_embeddings throw an exception: {}".format(
                        str(e)
                    )
                )
        else:
            raise AttributeError(
                "x must be a directory path (str) or a list of audio samples"
            )
        if not embd_lst:
            # np.concatenate rejects an empty list; return an empty array so
            # callers can test len(result) == 0.
            return np.empty((0, 0))
        return np.concatenate(embd_lst, axis=0)

    def calculate_embd_statistics(self, embd_lst):
        """
        Compute the mean vector and covariance matrix of a set of embeddings.
        Params:
        -- embd_lst : List or np.ndarray with one embedding per row.
        Returns:
        --   : Tuple (mu, sigma): mean over axis 0 and covariance with columns
               treated as variables.
        """
        if isinstance(embd_lst, list):
            embd_lst = np.array(embd_lst)
        mu = np.mean(embd_lst, axis=0)
        sigma = np.cov(embd_lst, rowvar=False)
        return mu, sigma

    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """
        Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
        Numpy implementation of the Frechet Distance.
        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
        Stable version by Dougal J. Sutherland.
        Params:
        -- mu1    : Mean vector of the first set of embeddings.
        -- sigma1 : Covariance matrix of the first set of embeddings.
        -- mu2    : Mean vector of the second set of embeddings.
        -- sigma2 : Covariance matrix of the second set of embeddings.
        -- eps    : Value added to the diagonal of both covariances when the
                    matrix square root of their product is not finite.
        Returns:
        --   : The Frechet Distance.
        Raises:
        -- ValueError if the matrix square root has a diagonal whose imaginary
           part is not close to zero.
        """
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)
        assert (
            mu1.shape == mu2.shape
        ), "Training and test mean vectors have different lengths"
        assert (
            sigma1.shape == sigma2.shape
        ), "Training and test covariances have different dimensions"
        diff = mu1 - mu2
        # Product might be almost singular.
        # The plain call (no `disp`) returns just the matrix, the same form the
        # retry below uses. NOTE(review): recent SciPy releases deprecate the
        # `disp` argument -- confirm against the pinned SciPy version.
        covmean = linalg.sqrtm(sigma1.dot(sigma2))
        if not np.isfinite(covmean).all():
            msg = (
                "fid calculation produces singular product; "
                "adding %s to diagonal of cov estimates"
            ) % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real
        tr_covmean = np.trace(covmean)
        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

    def __load_audio_files(self, dir):
        """
        Load every entry of a directory with load_audio_task on a thread pool.
        Params:
        -- dir : Directory whose entries are loaded.
        Returns:
        --   : List with the result of load_audio_task for each entry, in the
               order the entries were listed.
        """
        fnames = os.listdir(dir)  # List once so the bar total matches the tasks.
        pool = ThreadPool(self.audio_load_worker)
        pbar = tqdm(total=len(fnames), disable=(not self.verbose))

        def update(*a):
            pbar.update()

        if self.verbose:
            print("[Frechet Audio Distance] Loading audio from {}...".format(dir))
        try:
            task_results = [
                pool.apply_async(
                    load_audio_task, args=(os.path.join(dir, fname),), callback=update
                )
                for fname in fnames
            ]
            pool.close()
            pool.join()
        finally:
            # Harmless after a clean close/join; releases the worker threads
            # and the progress bar if submitting or joining failed.
            pool.terminate()
            pbar.close()
        return [k.get() for k in task_results]

    def score(self, background_dir, eval_dir, store_embds=False, limit_num=None):
        """
        Compute the Frechet Audio Distance between two directories of .wav files.
        Params:
        -- background_dir : Directory of generated samples.
        -- eval_dir       : Directory of groundtruth samples.
        -- store_embds    : If True, save both embedding arrays to
                            "embds_background.npy" and "embds_eval.npy".
        -- limit_num      : Forwarded to get_embeddings for both directories.
        Returns:
        --   : {"frechet_audio_distance": value} on success, or -1 when either
               set yields no embeddings or any exception is raised.
        """
        try:
            embds_background = self.get_embeddings(background_dir, limit_num=limit_num)
            embds_eval = self.get_embeddings(eval_dir, limit_num=limit_num)
            if store_embds:
                np.save("embds_background.npy", embds_background)
                np.save("embds_eval.npy", embds_eval)
            if len(embds_background) == 0:
                print(
                    "[Frechet Audio Distance] background set dir is empty, exitting..."
                )
                return -1
            if len(embds_eval) == 0:
                print("[Frechet Audio Distance] eval set dir is empty, exitting...")
                return -1
            mu_background, sigma_background = self.calculate_embd_statistics(
                embds_background
            )
            mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)
            fad_score = self.calculate_frechet_distance(
                mu_background, sigma_background, mu_eval, sigma_eval
            )
            return {"frechet_audio_distance": fad_score}
        except Exception as e:
            print("[Frechet Audio Distance] exception thrown, {}".format(str(e)))
            return -1