Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -50,11 +50,11 @@ vision_model = AutoModelForCausalLM.from_pretrained(
|
|
50 |
|
51 |
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
|
52 |
|
53 |
-
# Helper functions
|
54 |
# Initialize Parler-TTS
|
55 |
tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
|
56 |
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
|
57 |
|
|
|
58 |
# Helper functions
|
59 |
@spaces.GPU
|
60 |
def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
|
@@ -67,10 +67,12 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
|
|
67 |
conversation.append({"role": "user", "content": message})
|
68 |
|
69 |
input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(text_model.device)
|
|
|
70 |
streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
71 |
|
72 |
generate_kwargs = dict(
|
73 |
input_ids=input_ids,
|
|
|
74 |
max_new_tokens=max_new_tokens,
|
75 |
do_sample=temperature > 0,
|
76 |
top_p=top_p,
|
@@ -85,7 +87,7 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
|
|
85 |
thread.start()
|
86 |
|
87 |
buffer = ""
|
88 |
-
|
89 |
for new_text in streamer:
|
90 |
buffer += new_text
|
91 |
|
@@ -97,18 +99,10 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
|
|
97 |
with torch.no_grad():
|
98 |
audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
|
99 |
|
100 |
-
|
|
|
101 |
|
102 |
-
|
103 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
|
104 |
-
sf.write(temp_audio.name, audio_arr, tts_model.config.sampling_rate)
|
105 |
-
audio_files.append(temp_audio.name)
|
106 |
-
|
107 |
-
yield history + [[message, buffer]], audio_files
|
108 |
-
|
109 |
-
# Clean up temporary audio files
|
110 |
-
for audio_file in audio_files:
|
111 |
-
os.remove(audio_file)
|
112 |
|
113 |
@spaces.GPU
|
114 |
def process_vision_query(image, text_input):
|
@@ -212,7 +206,6 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
|
|
212 |
|
213 |
submit_btn.click(stream_text_chat, [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k], [chatbot, audio_output])
|
214 |
clear_btn.click(lambda: None, None, chatbot, queue=False)
|
215 |
-
|
216 |
with gr.Tab("Vision Model (Phi-3.5-vision)"):
|
217 |
with gr.Row():
|
218 |
with gr.Column(scale=1):
|
|
|
50 |
|
51 |
vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
|
52 |
|
|
|
53 |
# Initialize Parler-TTS
|
54 |
tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
|
55 |
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
|
56 |
|
57 |
+
# Helper functions
|
58 |
# Helper functions
|
59 |
@spaces.GPU
|
60 |
def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
|
|
|
67 |
conversation.append({"role": "user", "content": message})
|
68 |
|
69 |
input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(text_model.device)
|
70 |
+
attention_mask = torch.ones_like(input_ids) # Create attention mask
|
71 |
streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
72 |
|
73 |
generate_kwargs = dict(
|
74 |
input_ids=input_ids,
|
75 |
+
attention_mask=attention_mask, # Pass attention mask
|
76 |
max_new_tokens=max_new_tokens,
|
77 |
do_sample=temperature > 0,
|
78 |
top_p=top_p,
|
|
|
87 |
thread.start()
|
88 |
|
89 |
buffer = ""
|
90 |
+
audio_buffer = np.array([])
|
91 |
for new_text in streamer:
|
92 |
buffer += new_text
|
93 |
|
|
|
99 |
with torch.no_grad():
|
100 |
audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
|
101 |
|
102 |
+
new_audio = audio_generation.cpu().numpy().squeeze()
|
103 |
+
audio_buffer = np.concatenate((audio_buffer, new_audio))
|
104 |
|
105 |
+
yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
@spaces.GPU
|
108 |
def process_vision_query(image, text_input):
|
|
|
206 |
|
207 |
submit_btn.click(stream_text_chat, [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k], [chatbot, audio_output])
|
208 |
clear_btn.click(lambda: None, None, chatbot, queue=False)
|
|
|
209 |
with gr.Tab("Vision Model (Phi-3.5-vision)"):
|
210 |
with gr.Row():
|
211 |
with gr.Column(scale=1):
|