Chrishugs commited on
Commit
164e442
Β·
verified Β·
1 Parent(s): d740087

Upload 10 files

Browse files
Files changed (2) hide show
  1. app.py +192 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import tempfile
5
+ import os
6
+ import logging
7
+ from typing import Optional, Tuple
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Global model variable
14
+ model = None
15
+
16
+ def load_dia_model():
17
+ """Load the Dia model"""
18
+ global model
19
+ try:
20
+ logger.info("Loading Dia model...")
21
+ from dia import Dia
22
+
23
+ # Load with appropriate device and dtype
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ compute_dtype = "float16" if torch.cuda.is_available() else "float32"
26
+
27
+ model = Dia.from_pretrained(
28
+ "nari-labs/Dia-1.6B-0626",
29
+ device=device,
30
+ compute_dtype=compute_dtype
31
+ )
32
+ logger.info(f"Dia model loaded successfully on {device}")
33
+ return True
34
+ except Exception as e:
35
+ logger.error(f"Failed to load Dia model: {e}")
36
+ return False
37
+
38
+ def generate_speech(
39
+ text: str,
40
+ max_tokens: int = 3072,
41
+ temperature: float = 0.7,
42
+ top_p: float = 0.9
43
+ ) -> Tuple[Optional[str], str]:
44
+ """Generate speech from text using Dia model"""
45
+
46
+ if not text or not text.strip():
47
+ return None, "❌ Please enter some text to convert to speech"
48
+
49
+ if model is None:
50
+ return None, "❌ Model not loaded. Please refresh the page and try again."
51
+
52
+ try:
53
+ logger.info(f"Generating speech for text: {text[:50]}...")
54
+
55
+ # Generate audio using Dia model
56
+ audio_array = model.generate(
57
+ text=text.strip(),
58
+ max_tokens=max_tokens,
59
+ temperature=temperature,
60
+ top_p=top_p
61
+ )
62
+
63
+ # Save to temporary file
64
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
65
+ model.save_audio(temp_file.name, audio_array)
66
+
67
+ logger.info("Speech generation completed successfully")
68
+ return temp_file.name, f"βœ… Generated speech for: '{text[:50]}{'...' if len(text) > 50 else ''}'"
69
+
70
+ except Exception as e:
71
+ error_msg = f"❌ Error generating speech: {str(e)}"
72
+ logger.error(error_msg)
73
+ return None, error_msg
74
+
75
+ # Load model on startup
76
+ model_loaded = load_dia_model()
77
+
78
+ # Create Gradio interface
79
+ with gr.Blocks(
80
+ title="Dia TTS - Nari Voice Generator",
81
+ theme=gr.themes.Soft(),
82
+ css="""
83
+ .gradio-container {
84
+ max-width: 800px !important;
85
+ margin: auto !important;
86
+ }
87
+ """
88
+ ) as demo:
89
+
90
+ gr.Markdown("""
91
+ # πŸŽ™οΈ Dia TTS - Nari Voice Generator
92
+
93
+ Convert your text into natural, human-like speech using the advanced Dia text-to-speech model.
94
+
95
+ **Model**: `nari-labs/Dia-1.6B-0626`
96
+ """)
97
+
98
+ if not model_loaded:
99
+ gr.Markdown("⚠️ **Warning**: Model failed to load. Some functionality may not work.")
100
+
101
+ with gr.Row():
102
+ with gr.Column():
103
+ text_input = gr.Textbox(
104
+ label="πŸ“ Text Input",
105
+ placeholder="Enter the text you want to convert to speech...",
106
+ lines=4,
107
+ max_lines=10
108
+ )
109
+
110
+ with gr.Row():
111
+ max_tokens = gr.Slider(
112
+ minimum=512,
113
+ maximum=4096,
114
+ value=3072,
115
+ step=128,
116
+ label="🎯 Max Tokens"
117
+ )
118
+ temperature = gr.Slider(
119
+ minimum=0.1,
120
+ maximum=1.0,
121
+ value=0.7,
122
+ step=0.1,
123
+ label="🌑️ Temperature"
124
+ )
125
+ top_p = gr.Slider(
126
+ minimum=0.1,
127
+ maximum=1.0,
128
+ value=0.9,
129
+ step=0.1,
130
+ label="🎲 Top P"
131
+ )
132
+
133
+ generate_btn = gr.Button(
134
+ "🎡 Generate Speech",
135
+ variant="primary",
136
+ size="lg"
137
+ )
138
+
139
+ with gr.Column():
140
+ audio_output = gr.Audio(
141
+ label="πŸ”Š Generated Speech",
142
+ type="filepath"
143
+ )
144
+ status_output = gr.Textbox(
145
+ label="πŸ“Š Status",
146
+ interactive=False,
147
+ lines=2
148
+ )
149
+
150
+ # Event handlers
151
+ generate_btn.click(
152
+ fn=generate_speech,
153
+ inputs=[text_input, max_tokens, temperature, top_p],
154
+ outputs=[audio_output, status_output],
155
+ show_progress=True
156
+ )
157
+
158
+ # Examples
159
+ gr.Examples(
160
+ examples=[
161
+ ["Transform your text into natural, human-like speech with our advanced AI technology.", 3072, 0.7, 0.9],
162
+ ["The quick brown fox jumps over the lazy dog. This is a test of the Dia text-to-speech system.", 2048, 0.8, 0.9],
163
+ ["Welcome to the future of voice synthesis. Experience the power of AI-generated speech.", 3072, 0.6, 0.8],
164
+ ],
165
+ inputs=[text_input, max_tokens, temperature, top_p],
166
+ outputs=[audio_output, status_output],
167
+ fn=generate_speech,
168
+ cache_examples=False
169
+ )
170
+
171
+ gr.Markdown("""
172
+ ---
173
+
174
+ ### πŸ“š Usage Tips:
175
+ - **Max Tokens**: Controls the length of generated audio (higher = longer)
176
+ - **Temperature**: Controls randomness (0.1 = conservative, 1.0 = creative)
177
+ - **Top P**: Controls diversity of word selection (0.1 = focused, 1.0 = diverse)
178
+
179
+ ### βš™οΈ Technical Details:
180
+ - Model: Dia-1.6B-0626 by Nari Labs
181
+ - Output Format: WAV audio
182
+ - Recommended Text Length: 50-500 characters for best results
183
+ """)
184
+
185
+ if __name__ == "__main__":
186
+ demo.launch(
187
+ server_name="0.0.0.0",
188
+ server_port=7860,
189
+ share=False,
190
+ show_error=True,
191
+ quiet=False
192
+ )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchaudio>=2.0.0
3
+ numpy>=1.21.0
4
+ gradio>=4.0.0
5
+ huggingface-hub>=0.16.0
6
+ dac>=1.0.0
7
+ pydantic>=2.0.0