mariam-ahmed15 committed · Commit b7e88e7 · verified · 1 Parent(s): 85d448b

Create app.py

Files changed (1): app.py (+68, -0)
app.py ADDED
import torch
import gradio as gr
import librosa
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor

# 1. CONFIGURATION
MODEL_ID = "facebook/wav2vec2-xls-r-300m"
QUANTIZED_MODEL_PATH = "quantized_model.pth"

# 2. LOAD MODEL
print("Loading model architecture...")
# A. Load the base architecture and feature extractor. from_pretrained fills the
# backbone with pretrained weights, but they only serve as placeholders here:
# they are overwritten by the quantized checkpoint in step C.
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID, num_labels=2)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)

# B. Apply the quantization structure (must happen BEFORE loading the weights):
# this swaps the Linear layers to their INT8 counterparts so the state-dict keys match.
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
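# NOTE (assumption about how the checkpoint was produced): for the state-dict
# keys below to match, quantized_model.pth must come from the same recipe, roughly:
#   model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID, num_labels=2)
#   ... fine-tune on real/fake audio ...
#   model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
#   torch.save(model.state_dict(), "quantized_model.pth")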
# C. Load your trained quantized weights
print("Loading quantized weights...")
model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH, map_location=torch.device('cpu')))
model.eval()

# 3. DEFINE PREDICTION FUNCTION
def predict_audio(audio_path):
    if audio_path is None:
        return "No Audio Provided"

    # Load and resample audio to 16 kHz (the rate the model expects)
    speech_array, _ = librosa.load(audio_path, sr=16000)

    # Process inputs
    inputs = feature_extractor(
        speech_array,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert logits to probabilities
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Assuming Label 0 = Real, Label 1 = Deepfake (adjust to match your training!)
    fake_prob = probs[0][1].item()
    real_prob = probs[0][0].item()

    return {
        "Deepfake": fake_prob,
        "Real": real_prob
    }
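# Quick local smoke test for the function above (assumption: a file named
# sample.wav exists; any format librosa can decode will work):
#   print(predict_audio("sample.wav"))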
# 4. CREATE API INTERFACE
# gr.Interface creates a visual UI *and* exposes predict_audio as an API endpoint
iface = gr.Interface(
    fn=predict_audio,
    inputs=gr.Audio(type="filepath"),
    outputs=gr.Label(num_top_classes=2),
    title="Deepfake Audio Detection API",
    description="Upload an audio file to check whether it is real or fake."
)

iface.launch()
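For reference, a minimal client-side sketch of calling the endpoint once the Space is running, using recent versions of the gradio_client package. The Space id "mariam-ahmed15/your-space-name" and the file name "sample.wav" are placeholders, not values from this commit; gr.Interface registers the function under the default "/predict" API name.

# client.py - minimal sketch (pip install gradio_client)
from gradio_client import Client, handle_file

client = Client("mariam-ahmed15/your-space-name")  # hypothetical Space id
result = client.predict(
    handle_file("sample.wav"),   # audio file to classify
    api_name="/predict",         # default endpoint name for gr.Interface
)
print(result)  # label/confidences produced by predict_audio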