import torch
import gradio as gr
import os
import soundfile as sf
import numpy as np
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from huggingface_hub import login

# Global model variables
model = None
model_config = None
device = None

def load_model():
    """Load the pretrained model on startup"""
    global model, model_config, device
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")
    
    # Check for HF_TOKEN environment variable (set in Space settings)
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("Using HF_TOKEN for authentication")
        login(token=hf_token)
    else:
        print("Warning: HF_TOKEN not found. Model access may fail if authentication is required.")
        print("Please set HF_TOKEN as a secret in your Space settings.")
    
    # Download and load the pretrained model
    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]
    
    model = model.to(device).eval().requires_grad_(False)
    if device == "cuda":
        model = model.to(torch.float16)  # Half precision for efficiency (GPU only; fp16 ops are slow or unsupported on CPU)
    
    print(f"Model loaded successfully. Sample rate: {sample_rate}, Sample size: {sample_size}")
    return model, model_config

def generate_audio(prompt, seconds_total=11):
    """Generate 4 audio variations from a text prompt"""
    global model, model_config, device
    
    if model is None:
        return None, None, None, None, "Model not loaded. Please wait..."
    
    if not prompt or not prompt.strip():
        return None, None, None, None, "Please enter a text prompt."
    
    # Set up text and timing conditioning (repeat for batch_size)
    conditioning = [{
        "prompt": prompt,
        "seconds_total": seconds_total
    }] * 4  # Repeat for batch_size=4
    
    # Generate 4 variations using batch_size=4
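    # Note: steps=8 with the "pingpong" sampler and cfg_scale=1.0 follow the
    # fast-inference settings suggested for this model; this is an assumption
    # based on the model card, so adjust if your checkpoint documents otherwise.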
    try:
        output = generate_diffusion_cond(
            model,
            steps=8,
            cfg_scale=1.0,
            conditioning=conditioning,
            sample_size=model_config["sample_size"],
            sampler_type="pingpong",
            device=device,
            batch_size=4  # Generate 4 variations
        )
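        # output is expected to be a single tensor of shape
        # [batch_size, channels, sample_size], one waveform per variation
        # (which is why output[i] below yields [channels, samples]).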
        
        sample_rate = model_config["sample_rate"]
        audio_files = []
        
        # Process each variation in the batch
        for i in range(4):
            # Extract single variation: [channels, samples]
            audio = output[i]  # Shape: [channels, samples]
            
            # Peak normalize, clip, convert to float32 numpy array
            audio = audio.to(torch.float32)
            audio_max = torch.max(torch.abs(audio))
            if audio_max > 0:
                audio = audio.div(audio_max)
            audio = audio.clamp(-1, 1).cpu().numpy()
            
            # Transpose to [samples, channels] for soundfile
            if audio.ndim == 1:
                audio = audio.reshape(-1, 1)
            else:
                audio = audio.T  # [channels, samples] -> [samples, channels]
            
            # Save to temporary file using soundfile
            filename = f"output_variation_{i+1}.wav"
            sf.write(filename, audio, sample_rate)
            audio_files.append(filename)
        
        # Return 4 separate audio files and status message
        return audio_files[0], audio_files[1], audio_files[2], audio_files[3], f"Generated 4 variations for: '{prompt}'"
    
    except Exception as e:
        import traceback
        error_msg = f"Error generating audio: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None, None, None, error_msg

# Load model on startup
print("Initializing model...")
load_model()

# Create Gradio interface
with gr.Blocks(title="Stable Audio Open Small - 4 Variations") as demo:
    gr.Markdown("""
    # Stable Audio Open Small
    
    Generate 4 audio variations from a text prompt.
    
    **Model**: [stabilityai/stable-audio-open-small](https://huggingface.co/stabilityai/stable-audio-open-small)
    
    **Note**: This model requires accepting the license agreement. Make sure to set `HF_TOKEN` as a secret in your Space settings.
    
    Enter a text description and click Generate to create 4 different audio variations.
    """)
    
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 128 BPM tech house drum loop",
                lines=2
            )
            seconds_input = gr.Slider(
                minimum=1,
                maximum=11,
                value=11,
                step=1,
                label="Duration (seconds)",
                info="Maximum 11 seconds"
            )
            generate_btn = gr.Button("Generate", variant="primary")
        
        with gr.Column():
            status_output = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### Generated Audio Variations")
            audio_output_1 = gr.Audio(label="Variation 1", interactive=False)
            audio_output_2 = gr.Audio(label="Variation 2", interactive=False)
            audio_output_3 = gr.Audio(label="Variation 3", interactive=False)
            audio_output_4 = gr.Audio(label="Variation 4", interactive=False)
    
    generate_btn.click(
        fn=generate_audio,
        inputs=[prompt_input, seconds_input],
        outputs=[audio_output_1, audio_output_2, audio_output_3, audio_output_4, status_output],
        api_name="generate_audio"
    )
    
    gr.Markdown("""
    ### Tips
    - The model works best with English descriptions
    - Better at generating sound effects and field recordings than music
    - Each variation is sampled from different initial noise, so the four outputs differ even for the same prompt
    """)

if __name__ == "__main__":
    demo.launch()
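
# A minimal sketch of calling this app programmatically through the named API
# endpoint registered above (`api_name="generate_audio"`). It assumes the
# `gradio_client` package is installed; the Space id
# "user/stable-audio-open-small-demo" is hypothetical — substitute your own.
#
#     from gradio_client import Client
#
#     client = Client("user/stable-audio-open-small-demo")  # hypothetical Space id
#     v1, v2, v3, v4, status = client.predict(
#         "128 BPM tech house drum loop",  # text prompt
#         11,                              # duration in seconds (max 11)
#         api_name="/generate_audio",
#     )
#     print(status)  # v1..v4 are filepaths to the generated WAV variations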