prithivMLmods committed
Commit fda00e3 · verified · 1 parent: 7fcd908

Update app.py

Files changed (1)
  1. app.py +167 -132
app.py CHANGED
@@ -26,7 +26,6 @@ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
 DESCRIPTION = """
 # Gen Vision 🎃
-Separate Tabs for Chat, Image Generation (LoRA), Qwen2 VL OCR and Text-to-Speech
 """
 
 css = '''
@@ -73,7 +72,7 @@ def progress_bar_html(label: str) -> str:
 '''
 
 # -----------------------
-# Text Generation Setup (Chat)
+# Text Generation Setup
 # -----------------------
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -84,23 +83,28 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-# -----------------------
-# TTS Setup
-# -----------------------
 TTS_VOICES = [
-    "en-US-JennyNeural",
-    "en-US-GuyNeural",
+    "en-US-JennyNeural",  # @tts1
+    "en-US-GuyNeural",    # @tts2
 ]
 
+# -----------------------
+# Multimodal OCR Setup
+# -----------------------
+MODEL_ID = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to("cuda").eval()
+
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     """Convert text to speech using Edge TTS and save as MP3"""
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
     return output_file
 
-# -----------------------
-# Utility: Clean Chat History
-# -----------------------
 def clean_chat_history(chat_history):
     """
     Filter out any chat entries whose "content" is not a string.
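A side note on the TTS helper this hunk keeps: the two `TTS_VOICES` short names come from the Edge TTS catalogue, which the package can enumerate at runtime. A minimal sketch, assuming edge-tts's `list_voices()` helper (not used by this app):

```python
import asyncio

import edge_tts

async def main() -> None:
    # Enumerate the Edge TTS catalogue; each entry is a dict whose
    # "ShortName" looks like "en-US-JennyNeural" (the @tts1 voice above).
    voices = await edge_tts.list_voices()
    en_us = [v["ShortName"] for v in voices if v["ShortName"].startswith("en-US")]
    print(en_us[:5])

    # Same synthesize-and-save pattern as the app's text_to_speech().
    communicate = edge_tts.Communicate("Hello from Gen Vision", "en-US-JennyNeural")
    await communicate.save("output.mp3")

asyncio.run(main())
```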
@@ -112,19 +116,9 @@ def clean_chat_history(chat_history):
     return cleaned
 
 # -----------------------
-# Qwen2 VL OCR Setup
+# Stable Diffusion Image Generation Setup
 # -----------------------
-OCR_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
-processor = AutoProcessor.from_pretrained(OCR_MODEL_ID, trust_remote_code=True)
-model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-    OCR_MODEL_ID,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
 
-# -----------------------
-# Stable Diffusion Image Generation Setup (LoRA)
-# -----------------------
 MAX_SEED = np.iinfo(np.int32).max
 USE_TORCH_COMPILE = False
 ENABLE_CPU_OFFLOAD = False
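The next hunk's context line references `randomize_seed_fn`, whose body lies outside this diff. A plausible implementation, inferred only from its signature and from `MAX_SEED` above (an assumption, not the file's verbatim code):

```python
import random

import numpy as np

MAX_SEED = np.iinfo(np.int32).max  # 2**31 - 1, the full non-negative int32 range

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    # With "Randomize Seed" ticked, discard the incoming seed and draw a
    # fresh one; otherwise pass the user's seed through unchanged.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
```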
@@ -177,7 +171,17 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     return seed
 
 @spaces.GPU(duration=180, enable_queue=True)
-def generate_image(prompt: str, negative_prompt: str, seed: int, width: int, height: int, guidance_scale: float, randomize_seed: bool, lora_model: str):
+def generate_image(
+    prompt: str,
+    negative_prompt: str = "",
+    seed: int = 0,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3.0,
+    randomize_seed: bool = True,
+    lora_model: str = "Realism",
+    progress=gr.Progress(track_tqdm=True),
+):
     seed = int(randomize_seed_fn(seed, randomize_seed))
     effective_negative_prompt = negative_prompt  # Use provided negative prompt if any
     model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
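`generate_image` unpacks `LORA_OPTIONS[lora_model]` into `(model_name, weight_name, adapter_name)`; the table itself is defined outside this diff. A hedged sketch of the usual diffusers attach-and-select pattern, with hypothetical repo ids and filenames standing in for the Space's actual weights:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Hypothetical shape of one LORA_OPTIONS entry:
#   style label -> (LoRA repo id, weight filename, adapter name)
LORA_OPTIONS = {
    "Realism": ("some-user/example-realism-lora", "realism.safetensors", "realism"),
}

# Base SDXL pipeline; repo id illustrative, not necessarily the Space's.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

model_name, weight_name, adapter_name = LORA_OPTIONS["Realism"]
# Standard diffusers calls for loading and activating a LoRA adapter.
pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)
pipe.set_adapters(adapter_name)

image = pipe("a photo of a cat", width=1024, height=1024, guidance_scale=3.0).images[0]
image.save("cat.png")
```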
@@ -198,51 +202,78 @@ def generate_image(prompt: str, negative_prompt: str, seed: int, width: int, height: int, guidance_scale: float, randomize_seed: bool, lora_model: str):
     return image_paths, seed
 
 # -----------------------
-# Chat Generation Function (Text-only)
-# -----------------------
-def generate_chat(input_text: str, chat_history: list, max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
-    conversation = clean_chat_history(chat_history)
-    conversation.append({"role": "user", "content": input_text})
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-    input_ids = input_ids.to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        "input_ids": input_ids,
-        "streamer": streamer,
-        "max_new_tokens": max_new_tokens,
-        "do_sample": True,
-        "top_p": top_p,
-        "top_k": top_k,
-        "temperature": temperature,
-        "num_beams": 1,
-        "repetition_penalty": repetition_penalty,
-    }
-    t = Thread(target=model.generate, kwargs=generation_kwargs)
-    t.start()
-    outputs = []
-    for new_text in streamer:
-        outputs.append(new_text)
-    final_response = "".join(outputs)
-    chat_history.append({"role": "assistant", "content": final_response})
-    return chat_history
-
-# -----------------------
-# Qwen2 VL OCR Function (Multimodal)
+# Main Chat/Generation Function
 # -----------------------
-def generate_ocr(text: str, files, max_new_tokens: int):
+@spaces.GPU
+def generate(
+    input_dict: dict,
+    chat_history: list[dict],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+):
+    """
+    Generates chatbot responses with support for multimodal input, TTS, and image generation.
+    Special commands:
+    - "@tts1" or "@tts2": triggers text-to-speech.
+    - "@<lora_command>": triggers image generation using the new LoRA pipeline.
+      Available commands (case-insensitive): @realism, @pixar, @photoshoot, @clothing, @interior,
+      @fashion, @minimalistic, @modern, @animaliea, @wallpaper, @cars, @pencilart, @artminimalistic.
+    """
+    text = input_dict["text"]
+    files = input_dict.get("files", [])
+
+    # Check for an image generation command based on LoRA tags.
+    lora_mapping = {key.lower(): key for key in LORA_OPTIONS}
+    for key_lower, key in lora_mapping.items():
+        command_tag = "@" + key_lower
+        if text.strip().lower().startswith(command_tag):
+            prompt_text = text.strip()[len(command_tag):].strip()
+            yield progress_bar_html(f"Processing Image Generation ({key} style)")
+            image_paths, used_seed = generate_image(
+                prompt=prompt_text,
+                negative_prompt="",
+                seed=1,
+                width=1024,
+                height=1024,
+                guidance_scale=3,
+                randomize_seed=True,
+                lora_model=key,
+            )
+            yield progress_bar_html("Finalizing Image Generation")
+            yield gr.Image(image_paths[0])
+            return
+
+    # Check for a TTS command (@tts1 or @tts2).
+    tts_prefix = "@tts"
+    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
+    voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
+
+    if is_tts and voice_index:
+        voice = TTS_VOICES[voice_index - 1]
+        text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
+        conversation = [{"role": "user", "content": text}]
+    else:
+        voice = None
+        text = text.replace(tts_prefix, "").strip()
+        conversation = clean_chat_history(chat_history)
+        conversation.append({"role": "user", "content": text})
+
     if files:
-        if isinstance(files, list) and len(files) > 1:
+        if len(files) > 1:
             images = [load_image(image) for image in files]
-        elif isinstance(files, list) and len(files) == 1:
+        elif len(files) == 1:
             images = [load_image(files[0])]
         else:
-            images = [load_image(files)]
+            images = []
         messages = [{
             "role": "user",
-            "content": [*([{"type": "image", "image": image} for image in images]),
-                        {"type": "text", "text": text}]
+            "content": [
+                *[{"type": "image", "image": image} for image in images],
+                {"type": "text", "text": text},
+            ]
         }]
         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
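The dispatch that the new `generate()` adds is plain case-insensitive prefix matching on the message text. Extracted on its own, with a stub option table in place of the real LORA_OPTIONS, it behaves like this:

```python
# Stub option table; the real LORA_OPTIONS values live elsewhere in app.py.
LORA_OPTIONS = {"Realism": None, "Pixar": None}

def route(text: str):
    """Mirror of the diff's dispatch: LoRA tags first, then @tts1/@tts2, else plain chat."""
    lora_mapping = {key.lower(): key for key in LORA_OPTIONS}
    for key_lower, key in lora_mapping.items():
        command_tag = "@" + key_lower
        if text.strip().lower().startswith(command_tag):
            return ("image", key, text.strip()[len(command_tag):].strip())
    for i in (1, 2):
        if text.strip().lower().startswith(f"@tts{i}"):
            return ("tts", i, text.replace(f"@tts{i}", "").strip())
    return ("chat", None, text.strip())

assert route("@realism a neon cityscape") == ("image", "Realism", "a neon cityscape")
assert route("@tts2 What causes rainbows?") == ("tts", 2, "What causes rainbows?")
assert route("Python Program for Array Rotation")[0] == "chat"
```

One consequence of prefix matching: a tag that is a prefix of another would shadow it, so the tag names in LORA_OPTIONS need to stay prefix-free.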
@@ -250,84 +281,88 @@ def generate_ocr(text: str, files, max_new_tokens: int):
     generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
     thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
     thread.start()
+
     buffer = ""
+    yield progress_bar_html("Processing with Qwen2-VL OCR")
     for new_text in streamer:
         buffer += new_text
-        return buffer
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
+        yield buffer
     else:
-        return "No images provided."
+        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        input_ids = input_ids.to(model.device)
+        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            "input_ids": input_ids,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "top_p": top_p,
+            "top_k": top_k,
+            "temperature": temperature,
+            "num_beams": 1,
+            "repetition_penalty": repetition_penalty,
+        }
+        t = Thread(target=model.generate, kwargs=generation_kwargs)
+        t.start()
 
-# -----------------------
-# Text-to-Speech Function
-# -----------------------
-def generate_tts(text: str, voice: str):
-    output_file = asyncio.run(text_to_speech(text, voice))
-    return output_file
-
+        outputs = []
+        for new_text in streamer:
+            outputs.append(new_text)
+            yield "".join(outputs)
+
+        final_response = "".join(outputs)
+        yield final_response
+
+        if is_tts and voice:
+            output_file = asyncio.run(text_to_speech(final_response, voice))
+            yield gr.Audio(output_file, autoplay=True)
 
 # -----------------------
-# Gradio Interface with Tabs
+# Gradio Chat Interface
 # -----------------------
-with gr.Blocks(css=css, title="Gen Vision") as demo:
-    gr.Markdown(DESCRIPTION)
-
-    with gr.Tab("Chat Interface"):
-        with gr.Row():
-            chat_history = gr.Chatbot(label="Chat History")
-        with gr.Row():
-            chat_input = gr.Textbox(placeholder="Enter your message", label="Your Message")
-        with gr.Row():
-            max_new_tokens_slider = gr.Slider(label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-            temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
-        with gr.Row():
-            top_p_slider = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-            top_k_slider = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-            repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-        send_btn = gr.Button("Send")
-        send_btn.click(
-            fn=generate_chat,
-            inputs=[chat_input, chat_history, max_new_tokens_slider, temperature_slider, top_p_slider, top_k_slider, repetition_penalty_slider],
-            outputs=chat_history,
-        )
-
-    with gr.Tab("Image Generation"):
-        image_prompt = gr.Textbox(label="Prompt", placeholder="Enter image prompt")
-        negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt")
-        seed_input = gr.Number(label="Seed", value=0)
-        width_slider = gr.Slider(label="Width", minimum=256, maximum=2048, step=64, value=1024)
-        height_slider = gr.Slider(label="Height", minimum=256, maximum=2048, step=64, value=1024)
-        guidance_scale_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=3.0)
-        randomize_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
-        lora_dropdown = gr.Dropdown(label="LoRA Style", choices=list(LORA_OPTIONS.keys()), value="Realism")
-        generate_img_btn = gr.Button("Generate Image")
-        img_output = gr.Image(label="Generated Image")
-        seed_output = gr.Number(label="Used Seed")
-        generate_img_btn.click(
-            fn=generate_image,
-            inputs=[image_prompt, negative_prompt, seed_input, width_slider, height_slider, guidance_scale_slider, randomize_checkbox, lora_dropdown],
-            outputs=[img_output, seed_output],
-        )
-
-    with gr.Tab("Qwen 2 VL OCR"):
-        ocr_text = gr.Textbox(label="Text Prompt", placeholder="Enter prompt for OCR")
-        file_input = gr.File(label="Upload Images", file_count="multiple")
-        ocr_max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-        ocr_btn = gr.Button("Run OCR")
-        ocr_output = gr.Textbox(label="OCR Output")
-        ocr_btn.click(
-            fn=generate_ocr,
-            inputs=[ocr_text, file_input, ocr_max_new_tokens],
-            outputs=ocr_output,
-        )
-
-    with gr.Tab("Text-to-Speech"):
-        tts_text = gr.Textbox(label="Text", placeholder="Enter text for TTS")
-        voice_dropdown = gr.Dropdown(label="Voice", choices=TTS_VOICES, value=TTS_VOICES[0])
-        tts_btn = gr.Button("Generate Audio")
-        tts_audio = gr.Audio(label="Audio Output", type="filepath")
-        tts_btn.click(
-            fn=generate_tts,
-            inputs=[tts_text, voice_dropdown],
-            outputs=tts_audio,
-        )
+demo = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+    ],
+    examples=[
+        ['@realism Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic'],
+        ["@pixar A young man with light brown wavy hair and light brown eyes sitting in an armchair and looking directly at the camera, pixar style, disney pixar, office background, ultra detailed, 1 man"],
+        ["@realism A futuristic cityscape with neon lights"],
+        ["@photoshoot A portrait of a person with dramatic lighting"],
+        [{"text": "summarize the letter", "files": ["examples/1.png"]}],
+        ["Python Program for Array Rotation"],
+        ["@tts1 Who is Nikola Tesla, and why did he die?"],
+        ["@clothing Fashionable streetwear in an urban environment"],
+        ["@interior A modern living room interior with minimalist design"],
+        ["@fashion A runway model in haute couture"],
+        ["@minimalistic A simple and elegant design of a serene landscape"],
+        ["@modern A contemporary art piece with abstract geometric shapes"],
+        ["@animaliea A cute animal portrait with vibrant colors"],
+        ["@wallpaper A scenic mountain range perfect for a desktop wallpaper"],
+        ["@cars A sleek sports car cruising on a city street"],
+        ["@pencilart A detailed pencil sketch of a historic building"],
+        ["@artminimalistic An artistic minimalist composition with subtle tones"],
+        ["@tts2 What causes rainbows to form?"],
+    ],
+    cache_examples=False,
+    type="messages",
+    description=DESCRIPTION,
+    css=css,
+    fill_height=True,
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="default [text, vision]; scroll down to the examples to explore more art styles"),
+    stop_btn="Stop Generation",
+    multimodal=True,
+)
 
-    demo.queue(max_size=20).launch(share=True)
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch(share=True)
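Both branches of the rewritten handler stream the same way: `model.generate()` runs on a worker thread while the caller drains a `TextIteratorStreamer`, yielding the growing string so Gradio re-renders it. A minimal standalone version of the pattern (same model id as above; `device_map="auto"` assumes accelerate is installed):

```python
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

def stream_reply(prompt: str, max_new_tokens: int = 128):
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}], add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so it runs on a worker thread; the streamer is the channel back.
    Thread(target=model.generate, kwargs={
        "input_ids": input_ids, "streamer": streamer, "max_new_tokens": max_new_tokens,
    }).start()
    partial = ""
    for new_text in streamer:  # iteration ends when generation finishes
        partial += new_text
        yield partial

for chunk in stream_reply("Briefly, what is LoRA?"):
    print(chunk)
```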
 
 