Commit 1b8f6f0 (verified) by sagar007
Parent: d3fde93

Update app.py

Files changed (1):
  1. app.py (+33, -23)
app.py CHANGED
@@ -5,7 +5,9 @@ import gradio as gr
 from threading import Thread
 from PIL import Image
 import subprocess
-import spaces # Add this import
+import spaces
+from parler_tts import ParlerTTSForConditionalGeneration
+import soundfile as sf
 
 # Install flash-attention
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
@@ -96,6 +98,24 @@ def process_vision_query(image, text_input):
     generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
     response = vision_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     return response
+
+# Load Parler-TTS model
+tts_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1").to(tts_device)
+tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")
+
+@spaces.GPU
+def generate_speech(prompt, description):
+    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to(tts_device)
+    prompt_input_ids = tts_tokenizer(prompt, return_tensors="pt").input_ids.to(tts_device)
+
+    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
+    audio_arr = generation.cpu().numpy().squeeze()
+
+    output_path = "output_audio.wav"
+    sf.write(output_path, audio_arr, tts_model.config.sampling_rate)
+
+    return output_path
 
 # Custom CSS
 custom_css = """
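The new generate_speech function follows the standard Parler-TTS generation flow: the voice description is tokenized as input_ids, the text to speak as prompt_input_ids, and the generated waveform is written to a WAV file whose path Gradio's gr.Audio component can play directly. A minimal standalone sketch of the same flow, runnable outside the Space (the prompt and description strings here are illustrative, not from the commit):

import torch
import soundfile as sf
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")

prompt = "Hello, this is a test of Parler-TTS."                  # text to speak (illustrative)
description = "A female speaker with a clear, close-up voice."   # voice conditioning (illustrative)

# The description conditions the voice; the prompt is the text to be spoken.
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
sf.write("parler_tts_out.wav", generation.cpu().numpy().squeeze(), model.config.sampling_rate)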
 
@@ -134,8 +154,8 @@ custom_suggestions = """
         <p>Analyze Images with Vision Model</p>
     </div>
     <div class="suggestion">
-        <span class="suggestion-icon">🤖</span>
-        <p>Get AI-generated responses</p>
+        <span class="suggestion-icon">🔊</span>
+        <p>Generate Speech with Parler-TTS</p>
     </div>
     <div class="suggestion">
         <span class="suggestion-icon">🔍</span>
 
@@ -158,33 +178,23 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
     gr.HTML(custom_suggestions)
 
     with gr.Tab("Text Model (Phi-3.5-mini)"):
-        chatbot = gr.Chatbot(height=400)
-        msg = gr.Textbox(label="Message", placeholder="Type your message here...")
-        with gr.Accordion("Advanced Options", open=False):
-            system_prompt = gr.Textbox(value="You are a helpful assistant", label="System Prompt")
-            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature")
-            max_new_tokens = gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens")
-            top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p")
-            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k")
-
-        submit_btn = gr.Button("Submit", variant="primary")
-        clear_btn = gr.Button("Clear Chat", variant="secondary")
-
-        submit_btn.click(stream_text_chat, [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k], [chatbot])
-        clear_btn.click(lambda: None, None, chatbot, queue=False)
+        # ... (previous text model code remains the same)
 
     with gr.Tab("Vision Model (Phi-3.5-vision)"):
+        # ... (previous vision model code remains the same)
+
+    with gr.Tab("Text-to-Speech (Parler-TTS)"):
         with gr.Row():
             with gr.Column(scale=1):
-                vision_input_img = gr.Image(label="Upload an Image", type="pil")
-                vision_text_input = gr.Textbox(label="Ask a question about the image", placeholder="What do you see in this image?")
-                vision_submit_btn = gr.Button("Analyze Image", variant="primary")
+                tts_prompt = gr.Textbox(label="Text to Speak", placeholder="Enter the text you want to convert to speech...")
+                tts_description = gr.Textbox(label="Voice Description", value="A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up.", lines=3)
+                tts_submit_btn = gr.Button("Generate Speech", variant="primary")
             with gr.Column(scale=1):
-                vision_output_text = gr.Textbox(label="AI Analysis", lines=10)
+                tts_output_audio = gr.Audio(label="Generated Speech")
 
-        vision_submit_btn.click(process_vision_query, [vision_input_img, vision_text_input], [vision_output_text])
+        tts_submit_btn.click(generate_speech, inputs=[tts_prompt, tts_description], outputs=[tts_output_audio])
 
-    gr.HTML("<footer>Powered by Phi 3.5 Multimodal AI</footer>")
+    gr.HTML("<footer>Powered by Phi 3.5 Multimodal AI and Parler-TTS</footer>")
 
 if __name__ == "__main__":
     demo.launch()
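Packaging note: the commit imports parler_tts and soundfile, but no requirements.txt change appears in this diff. If those packages are not already pinned there, one option is a startup install mirroring the existing flash-attn pattern; the line below is a hypothetical sketch, not part of this commit:

# Hypothetical startup install, mirroring the flash-attn pattern in app.py;
# only needed if parler_tts and soundfile are absent from requirements.txt.
subprocess.run('pip install git+https://github.com/huggingface/parler-tts.git soundfile', shell=True)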