prithivMLmods committed on
Commit af0738e · verified · 1 Parent(s): 0b6db44

Update app.py

Files changed (1)
  1. app.py +62 -65
app.py CHANGED
@@ -25,7 +25,6 @@ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
 DESCRIPTION = """
 # QwQ Edge 💬
-**Note:** During image generation, a progress bar will appear both at the top of the interface and within the chat. For text generation, a loading animation will display until the response begins.
 """
 
 css = '''
@@ -40,34 +39,6 @@ h1 {
     background: #1565c0;
     border-radius: 100vh;
 }
-
-/* Custom styling for progress bars within chat */
-.progress-bar-container {
-    width: 100%;
-    margin-top: 5px;
-}
-
-.progress-bar {
-    width: 100%;
-    height: 4px;
-    background-color: #e0e0e0;
-    border-radius: 2px;
-}
-
-.progress-bar::-webkit-progress-bar {
-    background-color: #e0e0e0;
-    border-radius: 2px;
-}
-
-.progress-bar::-webkit-progress-value {
-    background-color: #90ee90; /* Light green */
-    border-radius: 2px;
-}
-
-.progress-bar::-moz-progress-bar {
-    background-color: #90ee90; /* Light green */
-    border-radius: 2px;
-}
 '''
 
 MAX_MAX_NEW_TOKENS = 2048
@@ -76,6 +47,23 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
+# Function to return an HTML snippet of a thin animated progress bar
+def progress_bar_html(message: str) -> str:
+    return f"""
+<div style="display: flex; align-items: center;">
+    <span style="margin-right: 8px;">{message}</span>
+    <div style="position: relative; width: 110px; height: 5px; background-color: #f8d7da; border-radius: 2px; overflow: hidden;">
+        <div style="position: absolute; width: 100%; height: 100%; background-color: #f5c6cb; animation: loading 1.5s linear infinite;"></div>
+    </div>
+</div>
+<style>
+@keyframes loading {{
+    0% {{ transform: translateX(-100%); }}
+    100% {{ transform: translateX(100%); }}
+}}
+</style>
+    """
+
 # Load text-only model and tokenizer
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -91,7 +79,7 @@ TTS_VOICES = [
     "en-US-GuyNeural",  # @tts2
 ]
 
-MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     MODEL_ID,
@@ -106,20 +94,24 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     return output_file
 
 def clean_chat_history(chat_history):
-    """Filter out non-string content to prevent concatenation errors"""
+    """
+    Filter out any chat entries whose "content" is not a string.
+    This helps prevent errors when concatenating previous messages.
+    """
     cleaned = []
     for msg in chat_history:
         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
             cleaned.append(msg)
     return cleaned
 
-# Stable Diffusion XL setup
-MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")
+# Environment variables and parameters for Stable Diffusion XL
+MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
-BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
+BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
 
+# Load the SDXL pipeline
 sd_pipe = StableDiffusionXLPipeline.from_pretrained(
     MODEL_ID_SD,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -128,19 +120,22 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
 ).to(device)
 sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
 
+# Ensure that the text encoder is in half-precision if using CUDA.
 if torch.cuda.is_available():
     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
 
+# Optional: compile the model for speedup if enabled
 if USE_TORCH_COMPILE:
     sd_pipe.compile()
 
+# Optional: offload parts of the model to CPU if needed
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
 
 MAX_SEED = np.iinfo(np.int32).max
 
 def save_image(img: Image.Image) -> str:
-    """Save a PIL image with a unique filename and return the path"""
+    """Save a PIL image with a unique filename and return the path."""
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
@@ -165,7 +160,7 @@ def generate_image_fn(
     num_images: int = 1,
     progress=gr.Progress(track_tqdm=True),
 ):
-    """Generate images using the SDXL pipeline"""
+    """Generate images using the SDXL pipeline."""
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)
 
@@ -183,11 +178,13 @@ def generate_image_fn(
         options["use_resolution_binning"] = True
 
     images = []
+    # Process in batches
     for i in range(0, num_images, BATCH_SIZE):
         batch_options = options.copy()
         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+        # Wrap the pipeline call in autocast if using CUDA
         if device.type == "cuda":
             with torch.autocast("cuda", dtype=torch.float16):
                 outputs = sd_pipe(**batch_options)
@@ -216,14 +213,11 @@ def generate(
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
+    # Handle image generation command
     if text.strip().lower().startswith("@image"):
         prompt = text[len("@image"):].strip()
-        # Initial message with progress bar at 0%
-        yield gr.HTML(
-            '<div>Generating Image...</div>'
-            '<progress class="progress-bar" value="0" max="100" '
-            'style="width:100%; height:4px; background-color:#e0e0e0;"></progress>'
-        )
+        # Show animated progress bar for image generation
+        yield gr.HTML(progress_bar_html("Generating Image"))
         image_paths, used_seed = generate_image_fn(
             prompt=prompt,
             negative_prompt="",
@@ -237,9 +231,9 @@ def generate(
             use_resolution_binning=True,
             num_images=1,
         )
-        # Final message with the image, progress bar at 100%
+        # Replace the progress bar with the generated image
         yield gr.Image(image_paths[0])
-        return
+        return  # Exit early
 
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
@@ -248,13 +242,16 @@ def generate(
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
+        # Clear previous chat history for a fresh TTS request.
         conversation = [{"role": "user", "content": text}]
     else:
         voice = None
+        # Remove any stray @tts tags and build the conversation history.
        text = text.replace(tts_prefix, "").strip()
        conversation = clean_chat_history(chat_history)
        conversation.append({"role": "user", "content": text})
 
+    # For multimodal chat with files (e.g. image + text)
     if files:
         if len(files) > 1:
             images = [load_image(image) for image in files]
@@ -276,18 +273,18 @@ def generate(
         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
         thread.start()
 
-        # Initial loading bar (indeterminate animation via CSS)
-        yield gr.HTML(
-            '<div>Generating response...</div>'
-            '<progress class="progress-bar" style="width:100%; height:4px; background-color:#e0e0e0;"></progress>'
-        )
         buffer = ""
+        # Show progress bar for thinking
+        yield gr.HTML(progress_bar_html("Thinking..."))
         for new_text in streamer:
             buffer += new_text
             buffer = buffer.replace("<|im_end|>", "")
             time.sleep(0.01)
-            # Yield only the text, replacing the loading bar
-            yield buffer
+            # Update with current text plus progress bar
+            interim_html = f"<div>{buffer}</div><div>{progress_bar_html('Thinking...')}</div>"
+            yield gr.HTML(interim_html)
+        # Final output without the progress bar
+        yield gr.HTML(f"<div>{buffer}</div>")
     else:
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
@@ -306,21 +303,21 @@ def generate(
             "num_beams": 1,
             "repetition_penalty": repetition_penalty,
         }
-        t = Thread(target=model.generate, kwargs=generation_kwargs)
-        t.start()
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
 
-        # Initial loading bar
-        yield gr.HTML(
-            '<div>Generating response...</div>'
-            '<progress class="progress-bar" style="width:100%; height:4px; background-color:#e0e0e0;"></progress>'
-        )
-        buffer = ""
+        outputs = []
+        # Show progress bar for thinking
+        yield gr.HTML(progress_bar_html("Thinking..."))
         for new_text in streamer:
-            buffer += new_text
-            # Yield only the text, replacing the loading bar
-            yield buffer
-
-        final_response = buffer
+            outputs.append(new_text)
+            interim_html = f"<div>{''.join(outputs)}</div><div>{progress_bar_html('Thinking...')}</div>"
+            yield gr.HTML(interim_html)
+        final_response = "".join(outputs)
+        # Final output without progress bar
+        yield gr.HTML(f"<div>{final_response}</div>")
+
+    # If TTS was requested, convert the final response to speech.
     if is_tts and voice:
         output_file = asyncio.run(text_to_speech(final_response, voice))
         yield gr.Audio(output_file, autoplay=True)
@@ -353,4 +350,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch(share=True)
+    demo.queue(max_size=20).launch(share=True)
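
For reference, the change boils down to the streaming pattern sketched below: the chat generator first yields an animated progress-bar placeholder, then re-yields the accumulated text together with the bar while tokens stream in, and finally yields the text alone. This is a minimal sketch, not the Space's actual code; fake_streamer() is a hypothetical stand-in for the TextIteratorStreamer used in app.py, and progress_bar_html is a simplified version of the helper this commit adds.

# Minimal sketch of the progress-bar streaming pattern introduced in this commit.
# Assumes gradio is installed; fake_streamer() is a hypothetical stand-in for
# transformers.TextIteratorStreamer.
import gradio as gr

def progress_bar_html(message: str) -> str:
    # Simplified version of the helper added in app.py: a label next to a thin bar.
    return (
        f'<div style="display: flex; align-items: center;">'
        f'<span style="margin-right: 8px;">{message}</span>'
        f'<div style="width: 110px; height: 5px; background-color: #f8d7da;"></div>'
        f'</div>'
    )

def fake_streamer():
    # Stand-in token stream for illustration only.
    yield from ["Hello", ", ", "world", "!"]

def generate_demo():
    # 1) Show the animated placeholder as soon as generation starts.
    yield gr.HTML(progress_bar_html("Thinking..."))
    buffer = ""
    for new_text in fake_streamer():
        buffer += new_text
        # 2) While streaming, re-yield the partial text plus the progress bar.
        yield gr.HTML(f"<div>{buffer}</div><div>{progress_bar_html('Thinking...')}</div>")
    # 3) Final yield drops the progress bar and keeps only the text.
    yield gr.HTML(f"<div>{buffer}</div>")

In app.py the same three-step shape is applied to both the multimodal branch and the text-only branch of generate().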