prithivMLmods commited on
Commit
ecce109
·
verified ·
1 Parent(s): 4f97d6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -38
app.py CHANGED
@@ -42,23 +42,6 @@ h1 {
42
  }
43
  '''
44
 
45
- def progress_bar_html(label: str) -> str:
46
- """Return an HTML snippet with a label and an animated, thin light-blue progress bar."""
47
- return f"""
48
- <div style="display: flex; align-items: center;">
49
- <span style="margin-right: 8px;">{label}</span>
50
- <div style="position: relative; width: 110px; height: 5px; background: #e0e0e0; border-radius: 5px; overflow: hidden;">
51
- <div style="width: 100%; height: 100%; background-color: lightblue; animation: progress-bar-animation 1s linear infinite;"></div>
52
- </div>
53
- </div>
54
- <style>
55
- @keyframes progress-bar-animation {{
56
- 0% {{ transform: translateX(-100%); }}
57
- 100% {{ transform: translateX(100%); }}
58
- }}
59
- </style>
60
- """
61
-
62
  MAX_MAX_NEW_TOKENS = 2048
63
  DEFAULT_MAX_NEW_TOKENS = 1024
64
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
@@ -105,6 +88,23 @@ def clean_chat_history(chat_history):
105
  cleaned.append(msg)
106
  return cleaned
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # Environment variables and parameters for Stable Diffusion XL
109
  MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
110
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
@@ -214,15 +214,13 @@ def generate(
214
  text = input_dict["text"]
215
  files = input_dict.get("files", [])
216
 
217
- tts_prefix = "@tts"
218
- is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
219
- voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
220
-
221
  if text.strip().lower().startswith("@image"):
222
  # Remove the "@image" tag and use the rest as prompt
223
  prompt = text[len("@image"):].strip()
224
- # Yield progress bar for image generation
225
- yield progress_bar_html("Generating Image")
 
226
  image_paths, used_seed = generate_image_fn(
227
  prompt=prompt,
228
  negative_prompt="",
@@ -236,10 +234,15 @@ def generate(
236
  use_resolution_binning=True,
237
  num_images=1,
238
  )
239
- # Yield the generated image, replacing the progress bar
 
240
  yield gr.Image(image_paths[0])
241
  return # Exit early
242
 
 
 
 
 
243
  if is_tts and voice_index:
244
  voice = TTS_VOICES[voice_index - 1]
245
  text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
@@ -252,6 +255,7 @@ def generate(
252
  conversation = clean_chat_history(chat_history)
253
  conversation.append({"role": "user", "content": text})
254
 
 
255
  if files:
256
  if len(files) > 1:
257
  images = [load_image(image) for image in files]
@@ -274,17 +278,18 @@ def generate(
274
  thread.start()
275
 
276
  buffer = ""
277
- # Yield initial progress bar for multimodal generation
278
- yield progress_bar_html("Thinking...")
 
279
  for new_text in streamer:
280
  buffer += new_text
281
  buffer = buffer.replace("<|im_end|>", "")
282
  time.sleep(0.01)
283
- # Update with partial text and progress bar
284
- yield f"<div>{buffer}</div><div>{progress_bar_html('Thinking...')}</div>"
285
- # Final output: remove progress bar
286
- yield f"<div>{buffer}</div>"
287
  else:
 
288
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
289
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
290
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -305,18 +310,16 @@ def generate(
305
  t = Thread(target=model.generate, kwargs=generation_kwargs)
306
  t.start()
307
 
308
- # Yield initial progress bar for text generation
309
- yield progress_bar_html("Thinking...")
310
  outputs = []
 
 
 
311
  for new_text in streamer:
312
  outputs.append(new_text)
313
- current_text = "".join(outputs)
314
- time.sleep(0.01)
315
- # Update message with partial text and progress bar
316
- yield f"<div>{current_text}</div><div>{progress_bar_html('Thinking...')}</div>"
317
  final_response = "".join(outputs)
318
- # Final output: only the final response text, progress bar removed.
319
- yield f"<div>{final_response}</div>"
 
320
 
321
  # If TTS was requested, convert the final response to speech.
322
  if is_tts and voice:
 
42
  }
43
  '''
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  MAX_MAX_NEW_TOKENS = 2048
46
  DEFAULT_MAX_NEW_TOKENS = 1024
47
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
88
  cleaned.append(msg)
89
  return cleaned
90
 
91
+ # Helper: returns HTML code for a thin light-green animated progress bar with a label.
92
+ def progress_bar_html(label: str) -> str:
93
+ return f'''
94
+ <div style="display: flex; align-items: center;">
95
+ <span>{label}</span>
96
+ <div style="flex-grow: 1; margin-left: 8px; height: 5px; background-color: lightgreen; overflow: hidden; position: relative;">
97
+ <div style="width: 100%; height: 100%; background: linear-gradient(90deg, rgba(255,255,255,0) 0%, rgba(255,255,255,0.5) 50%, rgba(255,255,255,0) 100%); animation: progressAnim 1s linear infinite;"></div>
98
+ </div>
99
+ </div>
100
+ <style>
101
+ @keyframes progressAnim {{
102
+ 0% {{ transform: translateX(-100%); }}
103
+ 100% {{ transform: translateX(100%); }}
104
+ }}
105
+ </style>
106
+ '''
107
+
108
  # Environment variables and parameters for Stable Diffusion XL
109
  MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
110
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 
214
  text = input_dict["text"]
215
  files = input_dict.get("files", [])
216
 
217
+ # For image generation triggered by "@image"
 
 
 
218
  if text.strip().lower().startswith("@image"):
219
  # Remove the "@image" tag and use the rest as prompt
220
  prompt = text[len("@image"):].strip()
221
+ # Yield a progress bar with label "Generating Image"
222
+ progress_component = gr.HTML(progress_bar_html("Generating Image"))
223
+ yield progress_component
224
  image_paths, used_seed = generate_image_fn(
225
  prompt=prompt,
226
  negative_prompt="",
 
234
  use_resolution_binning=True,
235
  num_images=1,
236
  )
237
+ # Clear the progress bar (replace with empty HTML) and then yield the image
238
+ yield gr.HTML.update(value="")
239
  yield gr.Image(image_paths[0])
240
  return # Exit early
241
 
242
+ tts_prefix = "@tts"
243
+ is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
244
+ voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
245
+
246
  if is_tts and voice_index:
247
  voice = TTS_VOICES[voice_index - 1]
248
  text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
 
255
  conversation = clean_chat_history(chat_history)
256
  conversation.append({"role": "user", "content": text})
257
 
258
+ # If there are attached image files, use multimodal processing
259
  if files:
260
  if len(files) > 1:
261
  images = [load_image(image) for image in files]
 
278
  thread.start()
279
 
280
  buffer = ""
281
+ # Yield a progress bar with label "Thinking..."
282
+ progress_component = gr.HTML(progress_bar_html("Thinking..."))
283
+ yield progress_component
284
  for new_text in streamer:
285
  buffer += new_text
286
  buffer = buffer.replace("<|im_end|>", "")
287
  time.sleep(0.01)
288
+ # Clear the progress bar and yield the final result text.
289
+ yield gr.HTML.update(value="")
290
+ yield buffer
 
291
  else:
292
+ # For pure text responses:
293
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
294
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
295
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
 
310
  t = Thread(target=model.generate, kwargs=generation_kwargs)
311
  t.start()
312
 
 
 
313
  outputs = []
314
+ # Yield a progress bar with label "Thinking..."
315
+ progress_component = gr.HTML(progress_bar_html("Thinking..."))
316
+ yield progress_component
317
  for new_text in streamer:
318
  outputs.append(new_text)
 
 
 
 
319
  final_response = "".join(outputs)
320
+ # Clear the progress bar and yield the final plain text result.
321
+ yield gr.HTML.update(value="")
322
+ yield final_response
323
 
324
  # If TTS was requested, convert the final response to speech.
325
  if is_tts and voice: