Spaces:
Runtime error
Runtime error
| #importing all the necessary packages | |
| import torch | |
| import transformers | |
| import gradio as gr | |
| from torchaudio.sox_effects import apply_effects_file | |
| from termcolor import colored | |
| from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForAudioFrameClassification | |
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# SoX effect chain applied, in order, to every input audio file.
EFFECTS = [
    ["remix", "-"],  # merge all the channels
    ["channels", "1"],  # convert to mono
    ["rate", "16000"],  # resample to 16000 Hz
    ["gain", "-1.0"],  # attenuate by 1 dB
    ["silence", "1", "0.1", "0.1%", "-1", "0.1", "0.1%"],  # SoX silence effect
    # ["pad", "0", "1.5"],  # (disabled) add 1.5 seconds of silence at the end
    ["trim", "0", "10"],  # keep the first 10 seconds
]

# Dataset-dependent threshold. NOTE(review): not referenced by the code
# visible in this section -- confirm whether it is still needed.
THRESHOLD = 0.85

# Pretrained checkpoint shared by the feature extractor and the model.
model_name = "microsoft/unispeech-sat-base-sd"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = UniSpeechSatForAudioFrameClassification.from_pretrained(model_name)
model = model.to(device)
def fn(path):
    """Run frame-level speaker diarization on one audio recording.

    The file is pre-processed with the SoX chain in ``EFFECTS``, converted to
    model inputs by ``feature_extractor``, and scored by ``model``. Each logit
    is squashed with a sigmoid and thresholded at 0.5 independently, so more
    than one speaker can be flagged as active in the same frame.

    Args:
        path: Filesystem path of the audio file supplied by the UI, or
            ``None``/empty when the optional input was left blank.

    Returns:
        A ``torch.int64`` tensor of 0/1 activity flags computed from the
        first batch item of the logits (presumably shaped
        ``(num_frames, num_speakers)`` -- confirm against the model docs),
        or a short explanatory string when there is no audio to analyse.
    """
    # The audio input is optional, so the callback can be invoked without a
    # file; bail out early instead of crashing inside SoX.
    if not path:
        return "No audio received. Please record a clip and try again."

    # Apply the effects to the audio input file.
    wav, _ = apply_effects_file(path, EFFECTS)

    # Silence removal can strip everything from a silent recording, leaving
    # nothing for the model to process.
    if wav.numel() == 0:
        return "No speech detected in the recording."

    # Extract features and move them to the device the model lives on.
    input_values = feature_extractor(
        wav.squeeze(0), return_tensors="pt", sampling_rate=16000
    ).input_values.to(device)

    # Inference only: no gradients needed. The logits are already on the
    # model's device, so no further transfer is required.
    with torch.no_grad():
        logits = model(input_values).logits

    probabilities = torch.sigmoid(logits[0])
    # 1 where the speaker is judged active in that frame, 0 otherwise.
    labels = (probabilities > 0.5).long()
    return labels
# UI inputs: a single, optional microphone recording that is handed to `fn`
# as a file path.
inputs = [
    gr.inputs.Audio(
        source="microphone",
        type="filepath",
        optional=True,
        label="Speaker #1",
    ),
]

# UI output: a text box showing whatever `fn` returns.
output = gr.outputs.Textbox(label="Output Text")

# Build the web interface and start serving it with request queuing enabled.
gr.Interface(
    fn=fn,
    inputs=inputs,
    outputs=output,
    theme="grass",
    title="Speaker diarization using UniSpeech-SAT and X-Vectors",
).launch(enable_queue=True)