sagar007 committed on
Commit
9f22f0a
·
verified ·
1 Parent(s): 8ce99fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -14
app.py CHANGED
@@ -50,11 +50,11 @@ vision_model = AutoModelForCausalLM.from_pretrained(
50
 
51
  vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
52
 
53
- # Helper functions
54
  # Initialize Parler-TTS
55
  tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
56
  tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
57
 
 
58
  # Helper functions
59
  @spaces.GPU
60
  def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
@@ -67,10 +67,12 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
67
  conversation.append({"role": "user", "content": message})
68
 
69
  input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(text_model.device)
 
70
  streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
71
 
72
  generate_kwargs = dict(
73
  input_ids=input_ids,
 
74
  max_new_tokens=max_new_tokens,
75
  do_sample=temperature > 0,
76
  top_p=top_p,
@@ -85,7 +87,7 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
85
  thread.start()
86
 
87
  buffer = ""
88
- audio_files = []
89
  for new_text in streamer:
90
  buffer += new_text
91
 
@@ -97,18 +99,10 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
97
  with torch.no_grad():
98
  audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
99
 
100
- audio_arr = audio_generation.cpu().numpy().squeeze()
 
101
 
102
- # Save the audio to a temporary file
103
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
104
- sf.write(temp_audio.name, audio_arr, tts_model.config.sampling_rate)
105
- audio_files.append(temp_audio.name)
106
-
107
- yield history + [[message, buffer]], audio_files
108
-
109
- # Clean up temporary audio files
110
- for audio_file in audio_files:
111
- os.remove(audio_file)
112
 
113
  @spaces.GPU
114
  def process_vision_query(image, text_input):
@@ -212,7 +206,6 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
212
 
213
  submit_btn.click(stream_text_chat, [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k], [chatbot, audio_output])
214
  clear_btn.click(lambda: None, None, chatbot, queue=False)
215
-
216
  with gr.Tab("Vision Model (Phi-3.5-vision)"):
217
  with gr.Row():
218
  with gr.Column(scale=1):
 
50
 
51
  vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
52
 
 
53
  # Initialize Parler-TTS
54
  tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
55
  tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
56
 
57
+ # Helper functions
58
  # Helper functions
59
  @spaces.GPU
60
  def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
 
67
  conversation.append({"role": "user", "content": message})
68
 
69
  input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(text_model.device)
70
+ attention_mask = torch.ones_like(input_ids) # Create attention mask
71
  streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
72
 
73
  generate_kwargs = dict(
74
  input_ids=input_ids,
75
+ attention_mask=attention_mask, # Pass attention mask
76
  max_new_tokens=max_new_tokens,
77
  do_sample=temperature > 0,
78
  top_p=top_p,
 
87
  thread.start()
88
 
89
  buffer = ""
90
+ audio_buffer = np.array([])
91
  for new_text in streamer:
92
  buffer += new_text
93
 
 
99
  with torch.no_grad():
100
  audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
101
 
102
+ new_audio = audio_generation.cpu().numpy().squeeze()
103
+ audio_buffer = np.concatenate((audio_buffer, new_audio))
104
 
105
+ yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
 
 
 
 
 
 
 
 
 
106
 
107
  @spaces.GPU
108
  def process_vision_query(image, text_input):
 
206
 
207
  submit_btn.click(stream_text_chat, [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k], [chatbot, audio_output])
208
  clear_btn.click(lambda: None, None, chatbot, queue=False)
 
209
  with gr.Tab("Vision Model (Phi-3.5-vision)"):
210
  with gr.Row():
211
  with gr.Column(scale=1):