import os
from audioldm_eval.datasets.load_mel import load_npy_data, MelPairedDataset, WaveDataset
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from audioldm_eval.metrics.fad import FrechetAudioDistance
from audioldm_eval import calculate_fid, calculate_isc, calculate_kid, calculate_kl
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from audioldm_eval.feature_extractors.panns import Cnn14
from audioldm_eval.audio.tools import save_pickle, load_pickle, write_json, load_json
from ssr_eval.metrics import AudioMetrics
import audioldm_eval.audio as Audio
import time
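

# EvaluationHelper compares a folder of generated .wav files against a folder of
# ground-truth .wav files and reports FAD, the Frechet distance on PANNs features,
# the Inception score, KID, KL divergence and, when file names match across the two
# folders, the paired LSD, PSNR and SSIM metrics.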
class EvaluationHelper:
    def __init__(self, sampling_rate, device, backbone="cnn14") -> None:
        self.device = device
        self.backbone = backbone
        self.sampling_rate = sampling_rate
        self.frechet = FrechetAudioDistance(
            use_pca=False,
            use_activation=False,
            verbose=False,
        )
        self.lsd_metric = AudioMetrics(self.sampling_rate)
        self.frechet.model = self.frechet.model.to(device)

        features_list = ["2048", "logits"]
        if self.sampling_rate == 16000:
            self.mel_model = Cnn14(
                features_list=features_list,
                sample_rate=16000,
                window_size=512,
                hop_size=160,
                mel_bins=64,
                fmin=50,
                fmax=8000,
                classes_num=527,
            )
        elif self.sampling_rate == 32000:
            self.mel_model = Cnn14(
                features_list=features_list,
                sample_rate=32000,
                window_size=1024,
                hop_size=320,
                mel_bins=64,
                fmin=50,
                fmax=14000,
                classes_num=527,
            )
        else:
            raise ValueError(
                "We only support evaluation at 16 kHz and 32 kHz sampling rates."
            )

        if self.sampling_rate == 16000:
            self._stft = Audio.TacotronSTFT(512, 160, 512, 64, 16000, 50, 8000)
        elif self.sampling_rate == 32000:
            self._stft = Audio.TacotronSTFT(1024, 320, 1024, 64, 32000, 50, 14000)
        else:
            raise ValueError(
                "We only support evaluation at 16 kHz and 32 kHz sampling rates."
            )

        self.mel_model.eval()
        self.mel_model.to(self.device)
        self.fbin_mean, self.fbin_std = None, None
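
    # The Cnn14 backbone exposes two feature layers: "2048" (the penultimate
    # embedding, used for FID and KID below) and "logits" (class logits, used for
    # the Inception score and the paired KL divergence).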

    def main(
        self,
        generate_files_path,
        groundtruth_path,
        limit_num=None,
    ):
        self.file_init_check(generate_files_path)
        self.file_init_check(groundtruth_path)

        same_name = self.get_filename_intersection_ratio(
            generate_files_path, groundtruth_path, limit_num=limit_num
        )
        metrics = self.calculate_metrics(
            generate_files_path, groundtruth_path, same_name, limit_num
        )
        return metrics
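
    # Minimal usage sketch (the paths below are hypothetical placeholders, not part
    # of the original code):
    #   device = torch.device("cuda:0")
    #   evaluator = EvaluationHelper(16000, device)
    #   metrics = evaluator.main("outputs/generated_wavs", "data/groundtruth_wavs")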

    def file_init_check(self, dir):
        assert os.path.exists(dir), "The path does not exist: %s" % dir
        assert len(os.listdir(dir)) > 1, "There are no files in %s" % dir

    def get_filename_intersection_ratio(
        self, dir1, dir2, threshold=0.99, limit_num=None
    ):
        self.datalist1 = [os.path.join(dir1, x) for x in os.listdir(dir1)]
        self.datalist1 = sorted(self.datalist1)
        self.datalist1 = [item for item in self.datalist1 if item.endswith(".wav")]

        self.datalist2 = [os.path.join(dir2, x) for x in os.listdir(dir2)]
        self.datalist2 = sorted(self.datalist2)
        self.datalist2 = [item for item in self.datalist2 if item.endswith(".wav")]

        data_dict1 = {os.path.basename(x): x for x in self.datalist1}
        data_dict2 = {os.path.basename(x): x for x in self.datalist2}

        keyset1 = set(data_dict1.keys())
        keyset2 = set(data_dict2.keys())

        intersect_keys = keyset1.intersection(keyset2)
        if (
            len(intersect_keys) / len(keyset1) > threshold
            and len(intersect_keys) / len(keyset2) > threshold
        ):
            # print(
            #     "+The two paths have %s intersecting files out of %s & %s files in total. Processing the two folders with same_name=True"
            #     % (len(intersect_keys), len(keyset1), len(keyset2))
            # )
            return True
        else:
            # print(
            #     "-The two paths have %s intersecting files out of %s & %s files in total. Processing the two folders with same_name=False"
            #     % (len(intersect_keys), len(keyset1), len(keyset2))
            # )
            return False
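
    # same_name=True (returned when the two folders share essentially the same file
    # names) enables the paired metrics below: LSD, PSNR and SSIM. It is also passed
    # through to the KL computation. With same_name=False the paired metrics are
    # reported as -1.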

    def calculate_lsd(self, pairedloader, same_name=True, time_offset=160 * 7):
        if not same_name:
            return {
                "lsd": -1,
                "ssim_stft": -1,
            }
        # print("Calculating LSD using a time offset of %s ..." % time_offset)
        lsd_avg = []
        ssim_stft_avg = []
        for _, _, filename, (audio1, audio2) in tqdm(pairedloader, leave=False):
            audio1 = audio1.cpu().numpy()[0, 0]
            audio2 = audio2.cpu().numpy()[0, 0]

            # If you use HiFi-GAN as the vocoder (verified on 2023-01-12), the generated
            # audio is delayed by seven frames, hence the default time_offset of
            # 160 * 7 samples (seven hops at the 16 kHz hop size of 160).
            audio1 = audio1[time_offset:]

            # Remove the DC offset and peak-normalize both signals before comparison.
            audio1 = audio1 - np.mean(audio1)
            audio2 = audio2 - np.mean(audio2)

            audio1 = audio1 / np.max(np.abs(audio1))
            audio2 = audio2 / np.max(np.abs(audio2))

            min_len = min(audio1.shape[0], audio2.shape[0])
            audio1, audio2 = audio1[:min_len], audio2[:min_len]

            try:
                result = self.lsd(audio1, audio2)
                lsd_avg.append(result["lsd"])
                ssim_stft_avg.append(result["ssim"])
            except Exception:
                continue

        return {"lsd": np.mean(lsd_avg), "ssim_stft": np.mean(ssim_stft_avg)}

    def lsd(self, audio1, audio2):
        result = self.lsd_metric.evaluation(audio1, audio2, None)
        return result

    def calculate_psnr_ssim(self, pairedloader, same_name=True):
        if not same_name:
            return {"psnr": -1, "ssim": -1}
        psnr_avg = []
        ssim_avg = []
        for mel_gen, mel_target, filename, _ in tqdm(pairedloader, leave=False):
            mel_gen = mel_gen.cpu().numpy()[0]
            mel_target = mel_target.cpu().numpy()[0]
            psnrval = psnr(mel_gen, mel_target)
            if np.isinf(psnrval):
                print("Infinite value encountered in psnr %s " % filename)
                continue
            psnr_avg.append(psnrval)
            ssim_avg.append(ssim(mel_gen, mel_target))
        return {"psnr": np.mean(psnr_avg), "ssim": np.mean(ssim_avg)}

    def calculate_metrics(
        self, generate_files_path, groundtruth_path, same_name, limit_num=None
    ):
        # Generation, target
        torch.manual_seed(0)

        num_workers = 0

        outputloader = DataLoader(
            WaveDataset(
                generate_files_path,
                self.sampling_rate,
                limit_num=limit_num,
            ),
            batch_size=1,
            sampler=None,
            num_workers=num_workers,
        )

        resultloader = DataLoader(
            WaveDataset(
                groundtruth_path,
                self.sampling_rate,
                limit_num=limit_num,
            ),
            batch_size=1,
            sampler=None,
            num_workers=num_workers,
        )

        pairedloader = DataLoader(
            MelPairedDataset(
                generate_files_path,
                groundtruth_path,
                self._stft,
                self.sampling_rate,
                self.fbin_mean,
                self.fbin_std,
                limit_num=limit_num,
            ),
            batch_size=1,
            sampler=None,
            num_workers=16,
        )

        out = {}

        metric_lsd = self.calculate_lsd(pairedloader, same_name=same_name)
        out.update(metric_lsd)

        featuresdict_2 = self.get_featuresdict(resultloader)
        featuresdict_1 = self.get_featuresdict(outputloader)

        # if cfg.have_kl:
        metric_psnr_ssim = self.calculate_psnr_ssim(pairedloader, same_name=same_name)
        out.update(metric_psnr_ssim)

        metric_kl, kl_ref, paths_1 = calculate_kl(
            featuresdict_1, featuresdict_2, "logits", same_name
        )
        out.update(metric_kl)

        metric_isc = calculate_isc(
            featuresdict_1,
            feat_layer_name="logits",
            splits=10,
            samples_shuffle=True,
            rng_seed=2020,
        )
        out.update(metric_isc)

        metric_fid = calculate_fid(
            featuresdict_1, featuresdict_2, feat_layer_name="2048"
        )
        out.update(metric_fid)

        # Gen, target
        fad_score = self.frechet.score(
            generate_files_path, groundtruth_path, limit_num=limit_num
        )
        out.update(fad_score)

        metric_kid = calculate_kid(
            featuresdict_1,
            featuresdict_2,
            feat_layer_name="2048",
            subsets=100,
            subset_size=1000,
            degree=3,
            gamma=None,
            coef0=1,
            rng_seed=2020,
        )
        out.update(metric_kid)

        # print("\n".join((f"{k}: {v:.7f}" for k, v in out.items())))
        # print("\n")
        # print(limit_num)
        # print(
        #     f'KL_Sigmoid: {out.get("kullback_leibler_divergence_sigmoid", float("nan")):8.5f};',
        #     f'KL: {out.get("kullback_leibler_divergence_softmax", float("nan")):8.5f};',
        #     f'PSNR: {out.get("psnr", float("nan")):.5f}',
        #     f'SSIM: {out.get("ssim", float("nan")):.5f}',
        #     f'ISc: {out.get("inception_score_mean", float("nan")):8.5f} ({out.get("inception_score_std", float("nan")):5f});',
        #     f'KID: {out.get("kernel_inception_distance_mean", float("nan")):.5f}',
        #     f'({out.get("kernel_inception_distance_std", float("nan")):.5f})',
        #     f'FD: {out.get("frechet_distance", float("nan")):8.5f};',
        #     f'FAD: {out.get("frechet_audio_distance", float("nan")):.5f}',
        #     f'LSD: {out.get("lsd", float("nan")):.5f}',
        #     f'SSIM_STFT: {out.get("ssim_stft", float("nan")):.5f}',
        # )

        result = {
            "frechet_distance": out.get("frechet_distance", float("nan")),
            "frechet_audio_distance": out.get("frechet_audio_distance", float("nan")),
            "kl_sigmoid": out.get("kullback_leibler_divergence_sigmoid", float("nan")),
            "kl_softmax": out.get("kullback_leibler_divergence_softmax", float("nan")),
            "lsd": out.get("lsd", float("nan")),
            "psnr": out.get("psnr", float("nan")),
            "ssim": out.get("ssim", float("nan")),
            "ssim_stft": out.get("ssim_stft", float("nan")),
            "is_mean": out.get("inception_score_mean", float("nan")),
            "is_std": out.get("inception_score_std", float("nan")),
            "kid_mean": out.get("kernel_inception_distance_mean", float("nan")),
            "kid_std": out.get("kernel_inception_distance_std", float("nan")),
        }
        result = {k: round(v, 4) for k, v in result.items()}

        json_path = generate_files_path + "_evaluation_results.json"
        write_json(result, json_path)
        return result
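
    # The result dict is also persisted next to the generated-audio folder as
    # "<generate_files_path>_evaluation_results.json". A minimal readback sketch,
    # assuming write_json emits standard JSON:
    #   import json
    #   with open(generate_files_path + "_evaluation_results.json") as f:
    #       saved_metrics = json.load(f)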

    def get_featuresdict(self, dataloader):
        out = None
        out_meta = None

        # transforms = StandardNormalizeAudio()

        for waveform, filename in tqdm(dataloader, leave=False):
            try:
                metadict = {
                    "file_path_": filename,
                }
                waveform = waveform.squeeze(1)

                # batch = transforms(batch)
                waveform = waveform.float().to(self.device)

                with torch.no_grad():
                    featuresdict = self.mel_model(waveform)

                # featuresdict = self.mel_model.convert_features_tuple_to_dict(features)
                featuresdict = {k: [v.cpu()] for k, v in featuresdict.items()}

                if out is None:
                    out = featuresdict
                else:
                    out = {k: out[k] + featuresdict[k] for k in out.keys()}

                if out_meta is None:
                    out_meta = metadict
                else:
                    out_meta = {k: out_meta[k] + metadict[k] for k in out_meta.keys()}
            except Exception as e:
                # import ipdb; ipdb.set_trace()  # debugging hook, disabled so one bad clip does not halt evaluation
                print("PANNs inference error: ", e)
                continue

        out = {k: torch.cat(v, dim=0) for k, v in out.items()}
        return {**out, **out_meta}

    def sample_from(self, samples, number_to_use):
        assert samples.shape[0] >= number_to_use
        rand_order = np.random.permutation(samples.shape[0])
        # Return a random subset of the requested size (the original indexed with
        # samples.shape[0], which returned every sample regardless of number_to_use).
        return samples[rand_order[:number_to_use], :]


'''
if __name__ == "__main__":
    import yaml
    import argparse
    from audioldm_eval import EvaluationHelper
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-g",
        "--generation_result_path",
        type=str,
        required=False,
        help="Path to the folder of generated audio files",
        default="/mnt/fast/datasets/audio/audioset/2million_audioset_wav/balanced_train_segments",
    )
    parser.add_argument(
        "-t",
        "--target_audio_path",
        type=str,
        required=False,
        help="Path to the folder of target (ground-truth) audio files",
        default="/mnt/fast/datasets/audio/audioset/2million_audioset_wav/eval_segments",
    )
    parser.add_argument(
        "-sr",
        "--sampling_rate",
        type=int,
        required=False,
        help="Audio sampling rate during evaluation",
        default=16000,
    )
    parser.add_argument(
        "-l",
        "--limit_num",
        type=int,
        required=False,
        help="Limit on the number of audio clips used for evaluation",
        default=None,
    )
    args = parser.parse_args()

    device = torch.device(f"cuda:{0}")
    evaluator = EvaluationHelper(args.sampling_rate, device)
    metrics = evaluator.main(
        args.generation_result_path,
        args.target_audio_path,
        limit_num=args.limit_num,
    )
    print(metrics)
'''
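
# Example invocation if the disabled __main__ block above is restored (the script
# filename and the paths are hypothetical placeholders):
#   python eval.py -g ./generated_wavs -t ./groundtruth_wavs -sr 16000 -l 100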