"""Gradio demo for Stable Audio Open Small.

Loads the pretrained model once at import time and exposes a small web UI
that turns a text prompt into four audio variations.
"""

import os
import tempfile
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import login
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

# Number of clips generated per request.  The UI below declares exactly this
# many ``gr.Audio`` outputs, so the two must be changed together.
NUM_VARIATIONS = 4

# Global model state, populated by load_model().
model = None
model_config = None
device = None


def load_model():
    """Load the pretrained model on startup.

    Picks CUDA when available (CPU otherwise), logs in to the Hugging Face
    Hub when the ``HF_TOKEN`` environment variable is set, downloads
    ``stabilityai/stable-audio-open-small`` and stores the result in the
    module-level ``model`` / ``model_config`` / ``device`` globals.

    Returns:
        The ``(model, model_config)`` pair that was stored in the globals.
    """
    global model, model_config, device

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")

    # HF_TOKEN is expected to be configured as a secret in the Space settings.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("Using HF_TOKEN for authentication")
        login(token=hf_token)
    else:
        print("Warning: HF_TOKEN not found. Model access may fail if authentication is required.")
        print("Please set HF_TOKEN as a secret in your Space settings.")

    # Download and load the pretrained model.
    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]

    model = model.to(device).eval().requires_grad_(False)
    if device == "cuda":
        # Half precision is only an efficiency win on GPU; on CPU it is
        # typically slower and, depending on the PyTorch build, may not be
        # supported by every operator, so the CPU fallback stays in float32.
        model = model.to(torch.float16)

    print(f"Model loaded successfully. Sample rate: {sample_rate}, Sample size: {sample_size}")
    return model, model_config


def _save_variation(audio, sample_rate, index):
    """Peak-normalize one generated clip and write it to a unique WAV file.

    Args:
        audio: Tensor for a single clip, indexed out of the generation batch.
        sample_rate: Sample rate passed to ``soundfile.write``.
        index: 1-based variation number, used only in the file name prefix.

    Returns:
        Path of the WAV file that was written.
    """
    audio = audio.to(torch.float32)
    peak = torch.max(torch.abs(audio))
    if peak > 0:
        audio = audio.div(peak)
    samples = audio.clamp(-1, 1).cpu().numpy()

    # soundfile expects [samples, channels].
    samples = samples.reshape(-1, 1) if samples.ndim == 1 else samples.T

    # A unique name per call: a fixed name in the working directory would let
    # overlapping requests overwrite each other's results.
    # NOTE: these files are not removed here; they live in the system temp
    # directory until it is cleaned.
    fd, path = tempfile.mkstemp(prefix=f"output_variation_{index}_", suffix=".wav")
    os.close(fd)
    sf.write(path, samples, sample_rate)
    return path


def generate_audio(prompt, seconds_total=11):
    """Generate 4 audio variations from a text prompt.

    Args:
        prompt: Text description of the desired audio.
        seconds_total: Requested duration passed to the model's timing
            conditioning.

    Returns:
        A tuple of four WAV file paths followed by a status message.  On any
        failure (model not loaded, empty prompt, generation error) the four
        paths are ``None`` and the message describes the problem.
    """
    failure = (None,) * NUM_VARIATIONS

    if model is None:
        return (*failure, "Model not loaded. Please wait...")

    if not prompt or not prompt.strip():
        return (*failure, "Please enter a text prompt.")

    # One independent conditioning dict per batch element.
    conditioning = [
        {"prompt": prompt, "seconds_total": seconds_total}
        for _ in range(NUM_VARIATIONS)
    ]

    try:
        output = generate_diffusion_cond(
            model,
            steps=8,
            cfg_scale=1.0,
            conditioning=conditioning,
            sample_size=model_config["sample_size"],
            sampler_type="pingpong",
            device=device,
            batch_size=NUM_VARIATIONS,
        )

        sample_rate = model_config["sample_rate"]
        audio_files = [
            _save_variation(output[i], sample_rate, i + 1)
            for i in range(NUM_VARIATIONS)
        ]
    except Exception as e:
        # Top-level UI boundary: report the failure in the status box instead
        # of crashing the request.
        error_msg = f"Error generating audio: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return (*failure, error_msg)

    return (*audio_files, f"Generated {NUM_VARIATIONS} variations for: '{prompt}'")


# Load model on startup
print("Initializing model...")
load_model()

# Create Gradio interface
with gr.Blocks(title="Stable Audio Open Small - 4 Variations") as demo:
    gr.Markdown("""
    # Stable Audio Open Small

    Generate up to 4 audio variations from a text prompt.

    **Model**: [stabilityai/stable-audio-open-small](https://huggingface.co/stabilityai/stable-audio-open-small)

    **Note**: This model requires accepting the license agreement. Make sure to set `HF_TOKEN` as a secret in your Space settings.

    Enter a text description and click Generate to create 4 different audio variations.
    """)

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 128 BPM tech house drum loop",
                lines=2,
            )
            seconds_input = gr.Slider(
                minimum=1,
                maximum=11,
                value=11,
                step=1,
                label="Duration (seconds)",
                info="Maximum 11 seconds",
            )
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            status_output = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### Generated Audio Variations")
            audio_output_1 = gr.Audio(label="Variation 1", interactive=False)
            audio_output_2 = gr.Audio(label="Variation 2", interactive=False)
            audio_output_3 = gr.Audio(label="Variation 3", interactive=False)
            audio_output_4 = gr.Audio(label="Variation 4", interactive=False)

    # The outputs list must stay in the same order as generate_audio's return
    # tuple: NUM_VARIATIONS audio paths followed by the status message.
    generate_btn.click(
        fn=generate_audio,
        inputs=[prompt_input, seconds_input],
        outputs=[
            audio_output_1,
            audio_output_2,
            audio_output_3,
            audio_output_4,
            status_output,
        ],
        api_name="generate_audio",
    )

    gr.Markdown("""
    ### Tips
    - The model works best with English descriptions
    - Better at generating sound effects and field recordings than music
    - Each variation uses a different random seed for diversity
    """)

if __name__ == "__main__":
    demo.launch()