prithivMLmods committed
Commit 54875b8 · verified · 1 Parent(s): bb78bca

Update app.py

Files changed (1):
  1. app.py +350 -249
app.py CHANGED
@@ -4,7 +4,6 @@ import uuid
  import json
  import time
  import asyncio
- import re
  from threading import Thread

  import gradio as gr
@@ -13,6 +12,7 @@ import torch
  import numpy as np
  from PIL import Image
  import edge_tts

  from transformers import (
      AutoModelForCausalLM,
@@ -24,56 +24,15 @@ from transformers import (
  from transformers.image_utils import load_image
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

- DESCRIPTION = """
- # Gen Vision 🎃
- """
-
- css = '''
- h1 {
-   text-align: center;
-   display: block;
- }
-
- #duplicate-button {
-   margin: auto;
-   color: #fff;
-   background: #1565c0;
-   border-radius: 100vh;
- }
- '''
-
  MAX_MAX_NEW_TOKENS = 2048
  DEFAULT_MAX_NEW_TOKENS = 1024
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

- # -----------------------
- # Progress Bar Helper
- # -----------------------
- def progress_bar_html(label: str) -> str:
-     """
-     Returns an HTML snippet for a thin progress bar with a label.
-     The progress bar is styled as a dark red animated bar.
-     """
-     return f'''
- <div style="display: flex; align-items: center;">
-     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-     <div style="width: 110px; height: 5px; background-color: #DDA0DD; border-radius: 2px; overflow: hidden;">
-         <div style="width: 100%; height: 100%; background-color: #FF00FF; animation: loading 1.5s linear infinite;"></div>
-     </div>
- </div>
- <style>
- @keyframes loading {{
-     0% {{ transform: translateX(-100%); }}
-     100% {{ transform: translateX(100%); }}
- }}
- </style>
- '''
-
- # -----------------------
- # Text Generation Setup
- # -----------------------
  model_id = "prithivMLmods/FastThink-0.5B-Tiny"
  tokenizer = AutoTokenizer.from_pretrained(model_id)
  model = AutoModelForCausalLM.from_pretrained(
@@ -83,170 +42,217 @@ model = AutoModelForCausalLM.from_pretrained(
  )
  model.eval()

  TTS_VOICES = [
      "en-US-JennyNeural", # @tts1
      "en-US-GuyNeural", # @tts2
  ]

- # -----------------------
- # Multimodal OCR Setup
- # -----------------------
- MODEL_ID = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID,
      trust_remote_code=True,
      torch_dtype=torch.float16
  ).to("cuda").eval()

- async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-     """Convert text to speech using Edge TTS and save as MP3"""
-     communicate = edge_tts.Communicate(text, voice)
-     await communicate.save(output_file)
-     return output_file
-
- def clean_chat_history(chat_history):
-     """
-     Filter out any chat entries whose "content" is not a string.
-     """
-     cleaned = []
-     for msg in chat_history:
-         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-             cleaned.append(msg)
-     return cleaned

- # -----------------------
- # Stable Diffusion Image Generation Setup
- # -----------------------

- MAX_SEED = np.iinfo(np.int32).max
- USE_TORCH_COMPILE = False
- ENABLE_CPU_OFFLOAD = False

- if torch.cuda.is_available():
-     pipe = StableDiffusionXLPipeline.from_pretrained(
-         "SG161222/RealVisXL_V4.0_Lightning",
-         torch_dtype=torch.float16,
-         use_safetensors=True,
-     )
-     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
-
-     # LoRA options with one example for each.
-     LORA_OPTIONS = {
-         "Realism": ("prithivMLmods/Canopus-Realism-LoRA", "Canopus-Realism-LoRA.safetensors", "rlms"),
-         "Pixar": ("prithivMLmods/Canopus-Pixar-Art", "Canopus-Pixar-Art.safetensors", "pixar"),
-         "Photoshoot": ("prithivMLmods/Canopus-Photo-Shoot-Mini-LoRA", "Canopus-Photo-Shoot-Mini-LoRA.safetensors", "photo"),
-         "Clothing": ("prithivMLmods/Canopus-Clothing-Adp-LoRA", "Canopus-Dress-Clothing-LoRA.safetensors", "clth"),
-         "Interior": ("prithivMLmods/Canopus-Interior-Architecture-0.1", "Canopus-Interior-Architecture-0.1δ.safetensors", "arch"),
-         "Fashion": ("prithivMLmods/Canopus-Fashion-Product-Dilation", "Canopus-Fashion-Product-Dilation.safetensors", "fashion"),
-         "Minimalistic": ("prithivMLmods/Pegasi-Minimalist-Image-Style", "Pegasi-Minimalist-Image-Style.safetensors", "minimalist"),
-         "Modern": ("prithivMLmods/Canopus-Modern-Clothing-Design", "Canopus-Modern-Clothing-Design.safetensors", "mdrnclth"),
-         "Animaliea": ("prithivMLmods/Canopus-Animaliea-Artism", "Canopus-Animaliea-Artism.safetensors", "Animaliea"),
-         "Wallpaper": ("prithivMLmods/Canopus-Liquid-Wallpaper-Art", "Canopus-Liquid-Wallpaper-Minimalize-LoRA.safetensors", "liquid"),
-         "Cars": ("prithivMLmods/Canes-Cars-Model-LoRA", "Canes-Cars-Model-LoRA.safetensors", "car"),
-         "PencilArt": ("prithivMLmods/Canopus-Pencil-Art-LoRA", "Canopus-Pencil-Art-LoRA.safetensors", "Pencil Art"),
-         "ArtMinimalistic": ("prithivMLmods/Canopus-Art-Medium-LoRA", "Canopus-Art-Medium-LoRA.safetensors", "mdm"),
-     }

-     # Load all LoRA weights
-     for model_name, weight_name, adapter_name in LORA_OPTIONS.values():
-         pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)
-     pipe.to("cuda")
- else:
-     pipe = StableDiffusionXLPipeline.from_pretrained(
-         "SG161222/RealVisXL_V4.0_Lightning",
-         torch_dtype=torch.float32,
-         use_safetensors=True,
-     ).to(device)

  def save_image(img: Image.Image) -> str:
-     """Save a PIL image with a unique filename and return the path."""
      unique_name = str(uuid.uuid4()) + ".png"
      img.save(unique_name)
      return unique_name

  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed

- @spaces.GPU(duration=180, enable_queue=True)
- def generate_image(
-     prompt: str,
-     negative_prompt: str = "",
-     seed: int = 0,
-     width: int = 1024,
-     height: int = 1024,
-     guidance_scale: float = 3.0,
-     randomize_seed: bool = True,
-     lora_model: str = "Realism",
-     progress=gr.Progress(track_tqdm=True),
- ):
-     seed = int(randomize_seed_fn(seed, randomize_seed))
-     effective_negative_prompt = negative_prompt # Use provided negative prompt if any
-     model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
-     pipe.set_adapters(adapter_name)
-     outputs = pipe(
-         prompt=prompt,
-         negative_prompt=effective_negative_prompt,
-         width=width,
-         height=height,
-         guidance_scale=guidance_scale,
-         num_inference_steps=28,
-         num_images_per_prompt=1,
-         cross_attention_kwargs={"scale": 0.65},
-         output_type="pil",
-     )
-     images = outputs.images
-     image_paths = [save_image(img) for img in images]
-     return image_paths, seed

- # -----------------------
- # Main Chat/Generation Function
- # -----------------------
- @spaces.GPU
- def generate(
-     input_dict: dict,
-     chat_history: list[dict],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ):
-     """
-     Generates chatbot responses with support for multimodal input, TTS, and image generation.
-     Special commands:
-     - "@tts1" or "@tts2": triggers text-to-speech.
-     - "@<lora_command>": triggers image generation using the new LoRA pipeline.
-     Available commands (case-insensitive): @realism, @pixar, @photoshoot, @clothing, @interior, @fashion,
-     @minimalistic, @modern, @animaliea, @wallpaper, @cars, @pencilart, @artminimalistic.
-     """
      text = input_dict["text"]
      files = input_dict.get("files", [])
-
-     # Check for image generation command based on LoRA tags.
-     lora_mapping = { key.lower(): key for key in LORA_OPTIONS }
-     for key_lower, key in lora_mapping.items():
-         command_tag = "@" + key_lower
-         if text.strip().lower().startswith(command_tag):
-             prompt_text = text.strip()[len(command_tag):].strip()
-             yield progress_bar_html(f"Processing Image Generation ({key} style)")
-             image_paths, used_seed = generate_image(
-                 prompt=prompt_text,
-                 negative_prompt="",
-                 seed=1,
-                 width=1024,
-                 height=1024,
-                 guidance_scale=3,
-                 randomize_seed=True,
-                 lora_model=key,
-             )
-             yield progress_bar_html("Finalizing Image Generation")
-             yield gr.Image(image_paths[0])
-             return
-
-     # Check for TTS command (@tts1 or @tts2)
      tts_prefix = "@tts"
      is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
      voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -260,40 +266,31 @@ def generate(
      text = text.replace(tts_prefix, "").strip()
      conversation = clean_chat_history(chat_history)
      conversation.append({"role": "user", "content": text})
-
      if files:
-         if len(files) > 1:
-             images = [load_image(image) for image in files]
-         elif len(files) == 1:
-             images = [load_image(files[0])]
-         else:
-             images = []
          messages = [{
              "role": "user",
-             "content": [
-                 *[{"type": "image", "image": image} for image in images],
-                 {"type": "text", "text": text},
-             ]
          }]
-         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
          streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
          generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
          thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
          thread.start()
-
          buffer = ""
-         yield progress_bar_html("Processing with Qwen2VL OCR")
          for new_text in streamer:
-             buffer += new_text
-             buffer = buffer.replace("<|im_end|>", "")
              time.sleep(0.01)
              yield buffer
      else:
          input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
          if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
              input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
          input_ids = input_ids.to(model.device)
          streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
          generation_kwargs = {
@@ -309,60 +306,164 @@ def generate(
          }
          t = Thread(target=model.generate, kwargs=generation_kwargs)
          t.start()
-
          outputs = []
          for new_text in streamer:
              outputs.append(new_text)
              yield "".join(outputs)
-
          final_response = "".join(outputs)
          yield final_response
-
          if is_tts and voice:
-             output_file = asyncio.run(text_to_speech(final_response, voice))
-             yield gr.Audio(output_file, autoplay=True)

- # -----------------------
- # Gradio Chat Interface
- # -----------------------
- demo = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
-     ],
-     examples=[
-         ['@realism Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic'],
-         ["@pixar A young man with light brown wavy hair and light brown eyes sitting in an armchair and looking directly at the camera, pixar style, disney pixar, office background, ultra detailed, 1 man"],
-         ["@realism A futuristic cityscape with neon lights"],
-         ["@photoshoot A portrait of a person with dramatic lighting"],
-         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
-         ["Python Program for Array Rotation"],
-         ["@tts1 Who is Nikola Tesla, and why did he die?"],
-         ["@clothing Fashionable streetwear in an urban environment"],
-         ["@interior A modern living room interior with minimalist design"],
-         ["@fashion A runway model in haute couture"],
-         ["@minimalistic A simple and elegant design of a serene landscape"],
-         ["@modern A contemporary art piece with abstract geometric shapes"],
-         ["@animaliea A cute animal portrait with vibrant colors"],
-         ["@wallpaper A scenic mountain range perfect for a desktop wallpaper"],
-         ["@cars A sleek sports car cruising on a city street"],
-         ["@pencilart A detailed pencil sketch of a historic building"],
-         ["@artminimalistic An artistic minimalist composition with subtle tones"],
-         ["@tts2 What causes rainbows to form?"],
-     ],
-     cache_examples=False,
-     type="messages",
-     description=DESCRIPTION,
-     css=css,
-     fill_height=True,
-     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="default [text, vision] , scroll down examples to explore more art styles"),
-     stop_btn="Stop Generation",
-     multimodal=True,
- )

  if __name__ == "__main__":
-     demo.queue(max_size=20).launch(share=True)
 
@@ -4,7 +4,6 @@ import uuid
  import json
  import time
  import asyncio
  from threading import Thread

  import gradio as gr
@@ -13,6 +12,7 @@ import torch
  import numpy as np
  from PIL import Image
  import edge_tts
+ import cv2

  from transformers import (
      AutoModelForCausalLM,
@@ -24,56 +24,15 @@ from transformers import (
  from transformers.image_utils import load_image
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

+ # --------- Global Config and Model Loading ---------
  MAX_MAX_NEW_TOKENS = 2048
  DEFAULT_MAX_NEW_TOKENS = 1024
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+ MAX_SEED = np.iinfo(np.int32).max

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+ # For text-only generation (chat)
  model_id = "prithivMLmods/FastThink-0.5B-Tiny"
  tokenizer = AutoTokenizer.from_pretrained(model_id)
  model = AutoModelForCausalLM.from_pretrained(
@@ -83,170 +42,217 @@ model = AutoModelForCausalLM.from_pretrained(
  )
  model.eval()

+ # For TTS
  TTS_VOICES = [
      "en-US-JennyNeural", # @tts1
      "en-US-GuyNeural", # @tts2
  ]

+ # For multimodal Qwen2VL (OCR / video/text)
+ MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+ processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_QWEN,
      trust_remote_code=True,
      torch_dtype=torch.float16
  ).to("cuda").eval()

+ # For SDXL Image Generation
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # Set your SDXL model repository path via env variable
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))

+ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
+     MODEL_ID_SD,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     use_safetensors=True,
+     add_watermarker=False,
+ ).to(device)
+ sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+ if torch.cuda.is_available():
+     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
+ if USE_TORCH_COMPILE:
+     sd_pipe.compile()
+ if ENABLE_CPU_OFFLOAD:
+     sd_pipe.enable_model_cpu_offload()

+ # For SDXL quality styles and LoRA options (used in the image-gen tab)
+ LORA_OPTIONS = {
+     "Realism (face/character)👦🏻": ("prithivMLmods/Canopus-Realism-LoRA", "Canopus-Realism-LoRA.safetensors", "rlms"),
+     "Pixar (art/toons)🙀": ("prithivMLmods/Canopus-Pixar-Art", "Canopus-Pixar-Art.safetensors", "pixar"),
+     "Photoshoot (camera/film)📸": ("prithivMLmods/Canopus-Photo-Shoot-Mini-LoRA", "Canopus-Photo-Shoot-Mini-LoRA.safetensors", "photo"),
+     "Clothing (hoodies/pant/shirts)👔": ("prithivMLmods/Canopus-Clothing-Adp-LoRA", "Canopus-Dress-Clothing-LoRA.safetensors", "clth"),
+     "Interior Architecture (house/hotel)🏠": ("prithivMLmods/Canopus-Interior-Architecture-0.1", "Canopus-Interior-Architecture-0.1δ.safetensors", "arch"),
+     "Fashion Product (wearing/usable)👜": ("prithivMLmods/Canopus-Fashion-Product-Dilation", "Canopus-Fashion-Product-Dilation.safetensors", "fashion"),
+     "Minimalistic Image (minimal/detailed)🏞️": ("prithivMLmods/Pegasi-Minimalist-Image-Style", "Pegasi-Minimalist-Image-Style.safetensors", "minimalist"),
+     "Modern Clothing (trend/new)👕": ("prithivMLmods/Canopus-Modern-Clothing-Design", "Canopus-Modern-Clothing-Design.safetensors", "mdrnclth"),
+     "Animaliea (farm/wild)🫎": ("prithivMLmods/Canopus-Animaliea-Artism", "Canopus-Animaliea-Artism.safetensors", "Animaliea"),
+     "Liquid Wallpaper (minimal/illustration)🖼️": ("prithivMLmods/Canopus-Liquid-Wallpaper-Art", "Canopus-Liquid-Wallpaper-Minimalize-LoRA.safetensors", "liquid"),
+     "Canes Cars (realistic/futurecars)🚘": ("prithivMLmods/Canes-Cars-Model-LoRA", "Canes-Cars-Model-LoRA.safetensors", "car"),
+     "Pencil Art (characteristic/creative)✏️": ("prithivMLmods/Canopus-Pencil-Art-LoRA", "Canopus-Pencil-Art-LoRA.safetensors", "Pencil Art"),
+     "Art Minimalistic (paint/semireal)🎨": ("prithivMLmods/Canopus-Art-Medium-LoRA", "Canopus-Art-Medium-LoRA.safetensors", "mdm"),
+ }
+ style_list = [
+     {
+         "name": "3840 x 2160",
+         "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
+     },
+     {
+         "name": "2560 x 1440",
+         "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
+     },
+     {
+         "name": "HD+",
+         "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
+     },
+     {
+         "name": "Style Zero",
+         "prompt": "{prompt}",
+         "negative_prompt": "",
+     },
+ ]
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
+ DEFAULT_STYLE_NAME = "3840 x 2160"
+ STYLE_NAMES = list(styles.keys())

+ # --------- Utility Functions ---------
+ def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+     """Convert text to speech using Edge TTS and save as MP3"""
+     async def run_tts():
+         communicate = edge_tts.Communicate(text, voice)
+         await communicate.save(output_file)
+         return output_file
+     return asyncio.run(run_tts())

+ def clean_chat_history(chat_history):
+     """Remove non-string content from the chat history."""
+     return [msg for msg in chat_history if isinstance(msg, dict) and isinstance(msg.get("content"), str)]

  def save_image(img: Image.Image) -> str:
+     """Save a PIL image to a file with a unique filename."""
      unique_name = str(uuid.uuid4()) + ".png"
      img.save(unique_name)
      return unique_name

  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     return random.randint(0, MAX_SEED) if randomize_seed else seed

+ def progress_bar_html(label: str) -> str:
+     """Return an HTML snippet for a progress bar."""
+     return f'''
+ <div style="display: flex; align-items: center;">
+     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+     <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
+         <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+     </div>
+ </div>
+ <style>
+ @keyframes loading {{
+     0% {{ transform: translateX(-100%); }}
+     100% {{ transform: translateX(100%); }}
+ }}
+ </style>
+ '''
+
+ def downsample_video(video_path):
+     """Extract 10 evenly spaced frames from a video."""
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
+
+ def apply_style(style_name: str, positive: str, negative: str = ""):
+     """Apply a chosen quality style to the prompt."""
+     p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
+     return p.replace("{prompt}", positive), n + negative

+ # --------- Tab 1: Chat Interface (Multimodal) ---------
+ def chat_generate(input_dict: dict, chat_history: list,
+                   max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+                   temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
      text = input_dict["text"]
      files = input_dict.get("files", [])
+     lower_text = text.strip().lower()
+
+     # If image generation command
+     if lower_text.startswith("@image"):
+         prompt = text[len("@image"):].strip()
+         yield progress_bar_html("Generating Image")
+         image_paths, used_seed = generate_image_fn(
+             prompt=prompt,
+             negative_prompt="",
+             use_negative_prompt=False,
+             seed=1,
+             width=1024,
+             height=1024,
+             guidance_scale=3,
+             num_inference_steps=25,
+             randomize_seed=True,
+             use_resolution_binning=True,
+             num_images=1,
+         )
+         yield gr.Image.update(value=image_paths[0])
+         return
+
+     # If video inference command
+     if lower_text.startswith("@video-infer"):
+         prompt = text[len("@video-infer"):].strip()
+         if files:
+             video_path = files[0]
+             frames = downsample_video(video_path)
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
+             ]
+             for frame in frames:
+                 image, timestamp = frame
+                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                 image.save(image_path)
+                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+                 messages[1]["content"].append({"type": "image", "url": image_path})
+         else:
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
+             ]
+         inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt").to("cuda")
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             **inputs,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "repetition_penalty": repetition_penalty,
+         }
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         yield progress_bar_html("Processing video with Qwen2VL")
+         for new_text in streamer:
+             buffer += new_text.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+         return
+
+     # Check for TTS command
      tts_prefix = "@tts"
      is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
      voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -260,40 +266,31 @@ def generate(
      text = text.replace(tts_prefix, "").strip()
      conversation = clean_chat_history(chat_history)
      conversation.append({"role": "user", "content": text})
+
      if files:
+         # Handle multimodal chat with images
+         images = [load_image(f) for f in files]
          messages = [{
              "role": "user",
+             "content": [{"type": "image", "image": image} for image in images] + [{"type": "text", "text": text}]
          }]
+         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
          streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
          generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
          thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
          thread.start()
          buffer = ""
+         yield progress_bar_html("Thinking...")
          for new_text in streamer:
+             buffer += new_text.replace("<|im_end|>", "")
              time.sleep(0.01)
              yield buffer
      else:
          input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
          if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
              input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+             gr.Warning(f"Trimmed input as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
          input_ids = input_ids.to(model.device)
          streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
          generation_kwargs = {
@@ -309,60 +306,164 @@ def generate(
          }
          t = Thread(target=model.generate, kwargs=generation_kwargs)
          t.start()
          outputs = []
+         yield progress_bar_html("Processing...")
          for new_text in streamer:
              outputs.append(new_text)
              yield "".join(outputs)
          final_response = "".join(outputs)
          yield final_response
          if is_tts and voice:
+             output_file = text_to_speech(final_response, voice)
+             yield gr.Audio.update(value=output_file)

+ # Helper function for image generation (used in chat @image branch)
+ @spaces.GPU(duration=60, enable_queue=True)
+ def generate_image_fn(prompt: str, negative_prompt: str = "", use_negative_prompt: bool = False,
+                       seed: int = 1, width: int = 1024, height: int = 1024,
+                       guidance_scale: float = 3, num_inference_steps: int = 25,
+                       randomize_seed: bool = False, use_resolution_binning: bool = True,
+                       num_images: int = 1, progress=None):
+     seed = int(randomize_seed_fn(seed, randomize_seed))
+     generator = torch.Generator(device=device).manual_seed(seed)
+     options = {
+         "prompt": [prompt] * num_images,
+         "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+         "width": width,
+         "height": height,
+         "guidance_scale": guidance_scale,
+         "num_inference_steps": num_inference_steps,
+         "generator": generator,
+         "output_type": "pil",
+     }
+     if use_resolution_binning:
+         options["use_resolution_binning"] = True
+
+     images = []
+     for i in range(0, num_images, BATCH_SIZE):
+         batch_options = options.copy()
+         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
+         if batch_options.get("negative_prompt") is not None:
+             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+         if device.type == "cuda":
+             with torch.autocast("cuda", dtype=torch.float16):
+                 outputs = sd_pipe(**batch_options)
+         else:
+             outputs = sd_pipe(**batch_options)
+         images.extend(outputs.images)
+     image_paths = [save_image(img) for img in images]
+     return image_paths, seed
+
+ # --------- Tab 2: SDXL Image Generation ---------
+ @spaces.GPU(duration=180, enable_queue=True)
+ def sdxl_generate(prompt: str, negative_prompt: str = "", use_negative_prompt: bool = True,
+                   seed: int = 0, width: int = 1024, height: int = 1024, guidance_scale: float = 3,
+                   randomize_seed: bool = False, style_name: str = DEFAULT_STYLE_NAME,
+                   lora_model: str = "Realism (face/character)👦🏻", progress=None):
+     seed = int(randomize_seed_fn(seed, randomize_seed))
+     positive_prompt, effective_negative_prompt = apply_style(style_name, prompt, negative_prompt)
+     if not use_negative_prompt:
+         effective_negative_prompt = ""
+     model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
+     # Set the adapter for the current generation
+     sd_pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)
+     sd_pipe.set_adapters(adapter_name)
+     images = sd_pipe(
+         prompt=positive_prompt,
+         negative_prompt=effective_negative_prompt,
+         width=width,
+         height=height,
+         guidance_scale=guidance_scale,
+         num_inference_steps=20,
+         num_images_per_prompt=1,
+         cross_attention_kwargs={"scale": 0.65},
+         output_type="pil",
+     ).images
+     image_paths = [save_image(img) for img in images]
+     return image_paths, seed
+
+ # --------- Tab 3: Qwen2VL OCR & Text Generation ---------
+ def qwen2vl_ocr_textgen(prompt: str, image_file):
+     if image_file is None:
+         return "Please upload an image."
+     # Load the image
+     image = load_image(image_file)
+     messages = [
+         {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+         {"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "image", "image": image}]}
+     ]
+     inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True,
+                                            return_dict=True, return_tensors="pt").to("cuda")
+     outputs = model_m.generate(
+         **inputs,
+         max_new_tokens=1024,
+         do_sample=True,
+         temperature=0.6,
+         top_p=0.9,
+         top_k=50,
+         repetition_penalty=1.2,
+     )
+     response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+     return response
+
+ # --------- Building the Gradio Interface with Tabs ---------
+ with gr.Blocks(title="Combined Demo") as demo:
+     gr.Markdown("# Combined Demo: Chat, SDXL Image Gen & Qwen2VL OCR/TextGen")
+     with gr.Tabs():
+         # --- Tab 1: Chat Interface ---
+         with gr.Tab("Chat Interface"):
+             chat_interface = gr.ChatInterface(
+                 fn=chat_generate,
+                 additional_inputs=[
+                     gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+                     gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+                     gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+                     gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+                     gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+                 ],
+                 examples=[
+                     ["Write the Python Program for Array Rotation"],
+                     [{"text": "summarize the letter", "files": ["examples/1.png"]}],
+                     [{"text": "@video-infer Describe the Ad", "files": ["examples/coca.mp4"]}],
+                     ["@image Chocolate dripping from a donut"],
+                     ["@tts1 Who is Nikola Tesla, and why did he die?"],
+                 ],
+                 cache_examples=False,
+                 type="messages",
+                 description="Use commands like **@image**, **@video-infer**, **@tts1**, or plain text.",
+                 textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple",
+                                              placeholder="Type your query (e.g., @tts1 for TTS, @image for image gen, etc.)"),
+                 stop_btn="Stop Generation",
+                 multimodal=True,
+             )
+         # --- Tab 2: SDXL Image Generation ---
+         with gr.Tab("SDXL Gen Image"):
+             with gr.Row():
+                 prompt_in = gr.Textbox(label="Prompt", placeholder="Enter prompt for image generation")
+                 negative_prompt_in = gr.Textbox(label="Negative prompt", placeholder="Enter negative prompt", lines=2)
+             with gr.Row():
+                 seed_in = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                 randomize_in = gr.Checkbox(label="Randomize seed", value=True)
+             with gr.Row():
+                 width_in = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
+                 height_in = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
+                 guidance_in = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, step=0.1, value=3.0)
+             style_in = gr.Radio(choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, label="Quality Style")
+             lora_in = gr.Dropdown(choices=list(LORA_OPTIONS.keys()), value="Realism (face/character)👦🏻", label="LoRA Selection")
+             run_button_img = gr.Button("Generate Image")
+             output_gallery = gr.Gallery(label="Generated Image", columns=1, preview=True)
+             seed_output = gr.Number(label="Seed used")
+             run_button_img.click(fn=sdxl_generate,
+                                  inputs=[prompt_in, negative_prompt_in, randomize_in, seed_in, width_in, height_in, guidance_in, randomize_in, style_in, lora_in],
+                                  outputs=[output_gallery, seed_output])
+         # --- Tab 3: Qwen2VL OCR & Text Generation ---
+         with gr.Tab("Qwen2VL OCR/TextGen"):
+             with gr.Row():
+                 qwen_prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt for OCR / text generation")
+                 qwen_image = gr.Image(label="Upload Image", type="filepath")
+             run_button_qwen = gr.Button("Run Qwen2VL")
+             qwen_output = gr.Textbox(label="Output")
+             run_button_qwen.click(fn=qwen2vl_ocr_textgen, inputs=[qwen_prompt, qwen_image], outputs=qwen_output)

  if __name__ == "__main__":
+     demo.queue(max_size=30).launch(share=True)
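For quick reference, below is a minimal, self-contained sketch (not part of the commit) of how the "@" prefix commands handled by chat_generate in the new app.py are matched before any model is called; the dispatch function and its return labels are illustrative stand-ins for the real branches:

    # Illustrative sketch only: mirrors the prefix checks in chat_generate.
    def dispatch(text: str) -> str:
        lower_text = text.strip().lower()
        if lower_text.startswith("@image"):
            return "image generation"      # handled by generate_image_fn
        if lower_text.startswith("@video-infer"):
            return "video inference"       # frames come from downsample_video
        if any(lower_text.startswith(f"@tts{i}") for i in (1, 2)):
            return "text-to-speech"        # spoken via edge_tts after generation
        return "text generation"           # default FastThink chat path

    assert dispatch("@image a castle at dusk") == "image generation"
    assert dispatch("@tts2 What causes rainbows to form?") == "text-to-speech"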