prithivMLmods committed
Commit 7fc6af3 · verified · 1 parent: 9e2bd4e

Update app.py

Files changed (1)
  1. app.py +68 -39
app.py CHANGED
@@ -4,6 +4,7 @@ import uuid
 import json
 import time
 import asyncio
+import re
 from threading import Thread
 
 import gradio as gr
@@ -47,7 +48,32 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load text-only model and tokenizer for text generation
+# -----------------------
+# Progress Bar Helper
+# -----------------------
+def progress_bar_html(label: str) -> str:
+    """
+    Returns an HTML snippet for a thin progress bar with a label.
+    The progress bar is styled as a dark red animated bar.
+    """
+    return f'''
+<div style="display: flex; align-items: center;">
+    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+    <div style="width: 110px; height: 5px; background-color: #f0f0f0; border-radius: 2px; overflow: hidden;">
+        <div style="width: 100%; height: 100%; background-color: #FF00FF; animation: loading 1.5s linear infinite;"></div>
+    </div>
+</div>
+<style>
+@keyframes loading {{
+    0% {{ transform: translateX(-100%); }}
+    100% {{ transform: translateX(100%); }}
+}}
+</style>
+    '''
+
+# -----------------------
+# Text Generation Setup
+# -----------------------
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -62,6 +88,9 @@ TTS_VOICES = [
     "en-US-GuyNeural", # @tts2
 ]
 
+# -----------------------
+# Multimodal OCR Setup
+# -----------------------
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -79,7 +108,6 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
 def clean_chat_history(chat_history):
     """
     Filter out any chat entries whose "content" is not a string.
-    This helps prevent errors when concatenating previous messages.
     """
     cleaned = []
     for msg in chat_history:
@@ -87,9 +115,9 @@ def clean_chat_history(chat_history):
             cleaned.append(msg)
     return cleaned
 
-# ------------------------------
-# New Image Generation Pipeline
-# ------------------------------
+# -----------------------
+# Stable Diffusion Image Generation Setup
+# -----------------------
 
 MAX_SEED = np.iinfo(np.int32).max
 USE_TORCH_COMPILE = False
@@ -124,6 +152,12 @@ if torch.cuda.is_available():
     for model_name, weight_name, adapter_name in LORA_OPTIONS.values():
         pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)
     pipe.to("cuda")
+else:
+    pipe = StableDiffusionXLPipeline.from_pretrained(
+        "SG161222/RealVisXL_V4.0_Lightning",
+        torch_dtype=torch.float32,
+        use_safetensors=True,
+    ).to(device)
 
 def save_image(img: Image.Image) -> str:
     """Save a PIL image with a unique filename and return the path."""
@@ -167,10 +201,9 @@ def generate_image(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-# ------------------------------
-# QwQ Edge Chat Interface
-# ------------------------------
-
+# -----------------------
+# Main Chat/Generation Function
+# -----------------------
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -193,13 +226,12 @@ def generate(
     files = input_dict.get("files", [])
 
     # Check for image generation command based on LoRA tags.
-    # Build a mapping with lowercase keys.
     lora_mapping = { key.lower(): key for key in LORA_OPTIONS }
     for key_lower, key in lora_mapping.items():
         command_tag = "@" + key_lower
         if text.strip().lower().startswith(command_tag):
             prompt_text = text.strip()[len(command_tag):].strip()
-            yield f" > Processing Image Generation {key} style ███████▒▒▒ 69%"
+            yield progress_bar_html(f"Processing Image Generation ({key} style)")
             image_paths, used_seed = generate_image(
                 prompt=prompt_text,
                 negative_prompt="",
@@ -210,7 +242,7 @@ def generate(
                 randomize_seed=True,
                 lora_model=key,
             )
-            yield " > Processing Image Generation ████████▒▒ 90%"
+            yield progress_bar_html("Finalizing Image Generation")
             yield gr.Image(image_paths[0])
             return
 
@@ -222,15 +254,13 @@ def generate(
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        # Clear previous chat history for a fresh TTS request.
        conversation = [{"role": "user", "content": text}]
     else:
         voice = None
-        # Remove any stray @tts tags and build the conversation history.
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
-
+
     if files:
         if len(files) > 1:
             images = [load_image(image) for image in files]
@@ -253,7 +283,7 @@ def generate(
         thread.start()
 
         buffer = ""
-        yield " > Processing with Qwen2VL Ocr ███████▒▒▒ 69%"
+        yield progress_bar_html("Processing with Qwen2VL Ocr")
         for new_text in streamer:
             buffer += new_text
             buffer = buffer.replace("<|im_end|>", "")
@@ -288,12 +318,13 @@ def generate(
         final_response = "".join(outputs)
         yield final_response
 
-        # If TTS was requested, convert the final response to speech.
         if is_tts and voice:
             output_file = asyncio.run(text_to_speech(final_response, voice))
             yield gr.Audio(output_file, autoplay=True)
 
-
+# -----------------------
+# Gradio Chat Interface
+# -----------------------
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
@@ -303,26 +334,25 @@ demo = gr.ChatInterface(
         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
     ],
-    examples = [
-
-        ["@realism Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
-        ["@pixar A young man with light brown wavy hair and light brown eyes sitting in an armchair and looking directly at the camera, pixar style, disney pixar, office background, ultra detailed, 1 man"],
-        ["@realism A futuristic cityscape with neon lights"],
-        ["@photoshoot A portrait of a person with dramatic lighting"],
-        [{"text": "summarize the letter", "files": ["examples/1.png"]}],
-        ["Python Program for Array Rotation"],
-        ["@tts1 Who is Nikola Tesla, and why did he die?"],
-        ["@clothing Fashionable streetwear in an urban environment"],
-        ["@interior A modern living room interior with minimalist design"],
-        ["@fashion A runway model in haute couture"],
-        ["@minimalistic A simple and elegant design of a serene landscape"],
-        ["@modern A contemporary art piece with abstract geometric shapes"],
-        ["@animaliea A cute animal portrait with vibrant colors"],
-        ["@wallpaper A scenic mountain range perfect for a desktop wallpaper"],
-        ["@cars A sleek sports car cruising on a city street"],
-        ["@pencilart A detailed pencil sketch of a historic building"],
-        ["@artminimalistic An artistic minimalist composition with subtle tones"],
-        ["@tts2 What causes rainbows to form?"],
+    examples=[
+        ['@realism Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic'],
+        ["@pixar A young man with light brown wavy hair and light brown eyes sitting in an armchair and looking directly at the camera, pixar style, disney pixar, office background, ultra detailed, 1 man"],
+        ["@realism A futuristic cityscape with neon lights"],
+        ["@photoshoot A portrait of a person with dramatic lighting"],
+        [{"text": "summarize the letter", "files": ["examples/1.png"]}],
+        ["Python Program for Array Rotation"],
+        ["@tts1 Who is Nikola Tesla, and why did he die?"],
+        ["@clothing Fashionable streetwear in an urban environment"],
+        ["@interior A modern living room interior with minimalist design"],
+        ["@fashion A runway model in haute couture"],
+        ["@minimalistic A simple and elegant design of a serene landscape"],
+        ["@modern A contemporary art piece with abstract geometric shapes"],
+        ["@animaliea A cute animal portrait with vibrant colors"],
+        ["@wallpaper A scenic mountain range perfect for a desktop wallpaper"],
+        ["@cars A sleek sports car cruising on a city street"],
+        ["@pencilart A detailed pencil sketch of a historic building"],
+        ["@artminimalistic An artistic minimalist composition with subtle tones"],
+        ["@tts2 What causes rainbows to form?"],
     ],
     cache_examples=False,
     type="messages",
@@ -335,5 +365,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    # To create a public link, set share=True in launch().
     demo.queue(max_size=20).launch(share=True)
 
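Below is a minimal, self-contained sketch of how a helper like progress_bar_html is typically consumed from a streaming Gradio handler. It is illustrative only and not part of the commit: the handler name respond, the echo logic, and the trimmed-down HTML are assumptions; only the pattern of yielding the bar first and letting the next yield replace it comes from the diff above.

# Illustrative sketch only (not part of this commit). Assumes a recent gradio
# with ChatInterface(type="messages") and a chat window that renders HTML.
import time

import gradio as gr


def progress_bar_html(label: str) -> str:
    # Trimmed-down stand-in for the helper added in app.py: a text label next
    # to a thin colored bar, returned as raw HTML for the chat message.
    return (
        '<div style="display:flex;align-items:center;">'
        f'<span style="margin-right:10px;font-size:14px;">{label}</span>'
        '<div style="width:110px;height:5px;background:#f0f0f0;">'
        '<div style="width:100%;height:100%;background:#FF00FF;"></div>'
        '</div></div>'
    )


def respond(message, history):
    # Yield the progress bar first; every later yield replaces the previous
    # partial output, so the bar disappears once real content arrives.
    yield progress_bar_html("Processing")
    time.sleep(1.0)  # stand-in for model inference / streaming
    yield f"Echo: {message}"


demo = gr.ChatInterface(fn=respond, type="messages")

if __name__ == "__main__":
    demo.queue().launch()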
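The new else: branch gives the Stable Diffusion pipeline a CPU fallback when CUDA is unavailable. A minimal sketch of that device-dependent setup follows; the CPU side mirrors the diff, while the GPU side's checkpoint and float16 dtype are assumptions added here for symmetry rather than taken from the commit.

# Illustrative sketch only (not part of this commit). Assumes diffusers and
# torch are installed; weights download from the Hub on first use.
import torch
from diffusers import StableDiffusionXLPipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    # GPU path: half precision keeps VRAM usage manageable (assumed dtype).
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "SG161222/RealVisXL_V4.0_Lightning",
        torch_dtype=torch.float16,
        use_safetensors=True,
    ).to(device)
else:
    # CPU fallback mirrors the new else: branch; float32, since half
    # precision is poorly supported on CPU.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "SG161222/RealVisXL_V4.0_Lightning",
        torch_dtype=torch.float32,
        use_safetensors=True,
    ).to(device)

# Example use: a single image from a text prompt (Lightning models need few steps).
image = pipe(prompt="A futuristic cityscape with neon lights", num_inference_steps=8).images[0]
image.save("sample.png")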