Spaces:
Paused
Paused
| import os | |
| import yaml | |
| import random | |
| import torch | |
| import torchaudio | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset | |
| from .processors import NaiveAudioProcessor, WaveformAudioProcessor, FbankAudioProcessor | |
| from ..utils import load_json | |
def label2caption(label, background_sound=None, template="{} can be heard"):
    r"""Turn per-sample label lists into natural-language captions.

    :param label: list of label lists, one inner list per sample.
    :param background_sound: optional background labels; either a single
        string applied to every sample or a list of label lists parallel
        to ``label``.
    :param template: format string receiving the comma-joined labels.
    :return: list of caption strings, one per entry of ``label``.
    """
    if background_sound is None:
        return [template.format(", ".join(names)) for names in label]
    # A single string means "same background for every sample".
    if isinstance(background_sound, str):
        background_sound = [[background_sound]] * len(label)
    assert len(label) == len(
        background_sound
    ), "the number of `background_sound` should match the number of `label`."
    return [
        template.format(", ".join(names))
        + " with the background sounds of {}".format(", ".join(bg))
        for names, bg in zip(label, background_sound)
    ]
class AudioDataset(Dataset):
    """Dataset of audio recordings described by captions and/or labels.

    Each item is a dict containing:
      - ``"fname"``: the wav path of the sample (two paths joined by a
        space when mixup selected a partner sample),
      - ``"caption"``: a text caption (only when ``include_caption``),
      - plus whatever ``audio_processor`` returns for the wav path(s).
    """

    def __init__(
        self,
        metadata_root: str = "/mnt/bn/lqhaoheliu/metadata/processed/dataset_root.json",
        dataset_name: list = None,
        split: str = "train",
        include_caption: bool = True,
        enable_mixup: bool = False,
        audio_processor: "NaiveAudioProcessor" = None,
    ):
        """
        Dataset that manages audio recordings.

        :param metadata_root: path of the JSON file mapping dataset names to
            their root paths and per-split metadata files.
        :param dataset_name: a dataset name or a list of dataset names;
            defaults to ``["audioset"]``.
        :param split: metadata key, e.g. "train", "test" or "val".
        :param include_caption: attach a text caption to every item.
        :param enable_mixup: with probability 0.5 per item, mix the sample
            with a second, randomly chosen sample.
        :param audio_processor: callable mapping wav path(s) to model
            inputs; defaults to ``NaiveAudioProcessor()``.
        """
        # Defaults are created per call here instead of in the signature:
        # a list/object default in the signature would be a single shared
        # instance evaluated once at definition time (mutable-default bug).
        if dataset_name is None:
            dataset_name = ["audioset"]
        if audio_processor is None:
            audio_processor = NaiveAudioProcessor()
        self.metadata_root = load_json(metadata_root)
        self.dataset_name = dataset_name
        self.split = split
        self.include_caption = include_caption
        self.audio_processor = audio_processor
        self.enable_mixup = enable_mixup
        # How two captions are merged when mixup picks a partner sample.
        self.mixture_caption_template = "{} | {}"
        if self.enable_mixup:
            print(
                f"Template for the caption of mixture is: {self.mixture_caption_template}"
            )
        self.build_dataset()
        print("Dataset initialization finished.")

    def __getitem__(self, index):
        datum = self.data[index]
        fname = datum["wav"]  # base name of the wav file
        mix_datum = {"wav": None}
        if self.enable_mixup:
            # Mix with a second random sample half of the time.
            if random.random() > 0.5:
                mix_datum = self.data[random.randint(0, len(self.data) - 1)]
                fname += " " + mix_datum["wav"]
        data = {"fname": fname}
        if self.include_caption:
            caption = self.get_caption_from_datum(
                datum,
                mix_datum,
                template_description=self.mixture_caption_template,
            )
            data.update({"caption": caption})
        data.update(self.audio_processor(datum["wav"], mix_datum["wav"]))
        return data

    def text_to_filename(self, text):
        """Sanitize ``text`` into a filesystem-safe name (spaces/quotes -> '_')."""
        return text.replace(" ", "_").replace("'", "_").replace('"', "_")

    def get_dataset_root_path(self, dataset):
        """Return the root directory registered for ``dataset`` in the metadata."""
        assert dataset in self.metadata_root.keys()
        return self.metadata_root[dataset]

    def get_dataset_metadata_path(self, dataset, key):
        """Return the metadata file path for ``dataset`` under ``key``.

        :param key: one of "train", "test", "val", "class_label_indices".
        :raises ValueError: if the dataset or key is not registered.
        """
        # EAFP: a missing dataset OR a missing key both surface as a clear
        # ValueError (previously a missing dataset silently returned None).
        try:
            return self.metadata_root["metadata"]["path"][dataset][key]
        except KeyError as err:
            raise ValueError(
                'Dataset %s does not have metadata "%s" specified' % (dataset, key)
            ) from err

    def __len__(self):
        return len(self.data)

    def _relative_path_to_absolute_path(self, metadata, dataset_name):
        """Prefix every relative "wav" entry with the dataset root path (in place)."""
        root_path = self.get_dataset_root_path(dataset_name)
        for i in range(len(metadata["data"])):
            assert "wav" in metadata["data"][i].keys(), metadata["data"][i]
            assert metadata["data"][i]["wav"][0] != "/", (
                "The dataset metadata should only contain relative path to the audio file: "
                + str(metadata["data"][i]["wav"])
            )
            metadata["data"][i]["wav"] = os.path.join(
                root_path, metadata["data"][i]["wav"]
            )
        return metadata

    def build_dataset(self):
        """Load and concatenate the metadata of every requested dataset into ``self.data``."""
        self.data = []
        print("Build dataset split %s from %s" % (self.split, self.dataset_name))
        if isinstance(self.dataset_name, str):
            data_json = load_json(
                self.get_dataset_metadata_path(self.dataset_name, key=self.split)
            )
            data_json = self._relative_path_to_absolute_path(
                data_json, self.dataset_name
            )
            self.data = data_json["data"]
        elif isinstance(self.dataset_name, list):
            for dataset_name in self.dataset_name:
                data_json = load_json(
                    self.get_dataset_metadata_path(dataset_name, key=self.split)
                )
                data_json = self._relative_path_to_absolute_path(
                    data_json, dataset_name
                )
                self.data += data_json["data"]
        else:
            raise Exception("Invalid data format")
        print("Data size: {}".format(len(self.data)))

    def is_contain_caption(self, datum):
        """True when ``datum`` has at least one key containing "caption"."""
        if datum is not None:
            caption_keys = [x for x in datum.keys() if ("caption" in x)]
            return len(caption_keys) > 0
        else:
            return False

    def _read_datum_caption(self, datum):
        """Return one caption from ``datum``, chosen uniformly among its caption keys."""
        if datum is not None:
            caption_keys = [x for x in datum.keys() if ("caption" in x)]
            random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
            return datum[caption_keys[random_index]]
        else:
            return ""  # NOTE: return empty string if datum is not provided

    def label_indices_to_text(
        self,
        datum,
        label_indices,
        template_description: str = "{}",  # e.g., "This audio contains the sound of {}"
    ):
        """Build a text description from a caption or from active label indices."""
        if self.is_contain_caption(datum):
            return self._read_datum_caption(datum)
        elif "label" in datum.keys():
            name_indices = torch.where(label_indices > 0.1)[0]
            labels = ""
            # NOTE(review): ``self.num2label`` is never defined in this
            # class — it must be provided by a subclass or set externally
            # before label-based captioning is used. TODO confirm.
            for id, each in enumerate(name_indices):
                if id == len(name_indices) - 1:
                    labels += "%s." % self.num2label[int(each)]
                else:
                    labels += "%s, " % self.num2label[int(each)]
            return template_description.format(labels)
        else:
            return ""  # NOTE: return empty string if both label and caption are not provided

    def get_sample_text_caption(self, datum, mix_datum, label_indices):
        """Concatenate the text of ``datum`` and (optionally) ``mix_datum``."""
        text = self.label_indices_to_text(datum, label_indices)
        if mix_datum is not None:
            text += " " + self.label_indices_to_text(mix_datum, label_indices)
        return text

    def get_caption_from_datum(
        self, datum, mix_datum=None, template_description="{} {}"
    ):
        """Return the caption of ``datum``, merged with ``mix_datum``'s via the template."""
        caption = ""
        if self.is_contain_caption(datum):
            caption += self._read_datum_caption(datum)
        # Mixup the caption if `mix_datum` is not None
        if mix_datum is not None and self.is_contain_caption(mix_datum):
            mix_caption = self._read_datum_caption(mix_datum)
            caption = template_description.format(caption, mix_caption)
        return caption
if __name__ == "__main__":
    import torch
    from tqdm import tqdm
    from torch.utils.data import DataLoader

    # Smoke-test the full pipeline: captions on, mixup on, fbank features.
    dataset = AudioDataset(
        dataset_name=["audiocaps"],
        include_caption=True,
        enable_mixup=True,
        audio_processor=FbankAudioProcessor(),
    )
    loader = DataLoader(dataset, batch_size=2, num_workers=0, shuffle=True)
    for batch_idx, batch in tqdm(enumerate(loader)):
        # print(batch["waveform"].size(), batch["log_mel_spec"].size())
        # print(batch['freq_energy_percentile'])
        # Drop into the debugger on every batch for manual inspection.
        import ipdb

        ipdb.set_trace()