Gijs Wijngaard commited on
Commit
880c54c
·
1 Parent(s): 7800b36
Files changed (2) hide show
  1. app.py +83 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import gradio as gr
3
+ import torch
4
+ import torchaudio
5
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
6
+
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ model_name = "ibm-granite/granite-speech-3.3-8b"
10
+
11
+ processor = AutoProcessor.from_pretrained(model_name)
12
+ tokenizer = processor.tokenizer
13
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
14
+ model_name, device_map=device, torch_dtype=torch.bfloat16
15
+ )
16
+
17
+
18
+ def _load_audio_mono_16k(file_path: str) -> torch.Tensor:
19
+ wav, sr = torchaudio.load(file_path, normalize=True)
20
+ if wav.shape[0] > 1:
21
+ wav = torch.mean(wav, dim=0, keepdim=True)
22
+ if sr != 16000:
23
+ wav = torchaudio.functional.resample(wav, sr, 16000)
24
+ return wav
25
+
26
+
27
+ def process_audio(audio_path: str, instruction: str, max_tokens: int = 200) -> str:
28
+ if not audio_path:
29
+ return "Please upload an audio file."
30
+
31
+ wav = _load_audio_mono_16k(audio_path)
32
+
33
+ date_string = datetime.now().strftime("%B %d, %Y")
34
+
35
+ system_prompt = (
36
+ "Knowledge Cutoff Date: April 2024.\n"
37
+ f"Today's Date: {date_string}.\n"
38
+ "You are Granite, developed by IBM. You are a helpful AI assistant"
39
+ )
40
+ user_prompt = f"<|audio|>{instruction.strip()}"
41
+ chat = [
42
+ {"role": "system", "content": system_prompt},
43
+ {"role": "user", "content": user_prompt},
44
+ ]
45
+ prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
46
+
47
+ model_inputs = processor(prompt, wav, device=device, return_tensors="pt").to(device)
48
+ outputs = model.generate(
49
+ **model_inputs,
50
+ max_new_tokens=int(max_tokens),
51
+ do_sample=False,
52
+ num_beams=1,
53
+ )
54
+
55
+ num_input_tokens = model_inputs["input_ids"].shape[-1]
56
+ new_tokens = torch.unsqueeze(outputs[0, num_input_tokens:], dim=0)
57
+ text = tokenizer.batch_decode(new_tokens, add_special_tokens=False, skip_special_tokens=True)[0]
58
+ return text
59
+
60
+
61
+ with gr.Blocks(title="Granite Speech Demo") as demo:
62
+ gr.Markdown("# Granite Speech-to-Text Demo")
63
+ gr.Markdown("Upload audio and transcribe with IBM Granite.")
64
+
65
+ with gr.Row():
66
+ with gr.Column():
67
+ audio_input = gr.Audio(type="filepath", label="Upload Audio")
68
+ instruction = gr.Textbox(
69
+ label="Instruction",
70
+ value="can you transcribe the speech into a written format?",
71
+ )
72
+ max_tokens = gr.Slider(50, 1000, value=200, step=50, label="Max Output Tokens")
73
+ submit_btn = gr.Button("Transcribe", variant="primary")
74
+ with gr.Column():
75
+ output_text = gr.Textbox(label="Output", lines=12)
76
+
77
+ submit_btn.click(process_audio, [audio_input, instruction, max_tokens], output_text)
78
+
79
+
80
+ if __name__ == "__main__":
81
+ demo.queue().launch(share=False, ssr_mode=False)
82
+
83
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.1.0
3
+ torchaudio>=2.1.0
4
+ transformers>=4.43.0
5
+ huggingface_hub>=0.23.0
6
+ accelerate>=0.30.0
7
+