prithivMLmods committed
Commit 78be7e8 · verified · 1 Parent(s): f4bb0af

Update app.py

Files changed (1)
  1. app.py +200 -242
app.py CHANGED
@@ -4,6 +4,7 @@ import uuid
 import time
 import asyncio
 from threading import Thread
 
 import gradio as gr
 import spaces
@@ -12,6 +13,7 @@ import numpy as np
 from PIL import Image
 import cv2
 
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -20,179 +22,26 @@ from transformers import (
     AutoProcessor,
 )
 from transformers.image_utils import load_image
-from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
 # ---------------------------
-# Global Settings & Utilities
 # ---------------------------
-
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-def save_image(img: Image.Image) -> str:
-    """Save a PIL image with a unique filename and return the path."""
-    unique_name = str(uuid.uuid4()) + ".png"
-    img.save(unique_name)
-    return unique_name
-
-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-    MAX_SEED = np.iinfo(np.int32).max
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    return seed
-
-def progress_bar_html(label: str) -> str:
-    """Returns an HTML snippet for a thin progress bar with a label."""
-    return f'''
-    <div style="display: flex; align-items: center;">
-        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-        <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-        </div>
-    </div>
-    <style>
-    @keyframes loading {{
-        0% {{ transform: translateX(-100%); }}
-        100% {{ transform: translateX(100%); }}
-    }}
-    </style>
-    '''
-
-# Helper function for the chat interface
-def apply_chat_template_for_text(conversation, add_generation_prompt=True):
-    """
-    Concatenates a conversation (list of dict with keys "role" and "content")
-    into a single string prompt. If add_generation_prompt is True, appends "assistant:".
-    """
-    prompt = ""
-    for msg in conversation:
-        prompt += f"{msg['role']}: {msg['content']}\n"
-    if add_generation_prompt:
-        prompt += "assistant:"
-    return prompt
-
-def clean_chat_history(chat_history):
-    """
-    Filter out any chat entries whose "content" is not a string.
-    """
-    cleaned = []
-    for msg in chat_history:
-        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-            cleaned.append(msg)
-    return cleaned
 
 # ---------------------------
-# 1. Chat Interface Tab
 # ---------------------------
-# Uses a text-only model: DeepHermes-3-Llama-3-3B-Preview-abliterated
 
-model_id_text = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
-tokenizer = AutoTokenizer.from_pretrained(model_id_text)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id_text,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-model.eval()
-
-@spaces.GPU
-def chat_generate(input_text: str, chat_history: list, max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
-    """
-    Chat generation using a text-only model.
-    """
-    # Prepare conversation by cleaning history and appending the new user message.
-    conversation = clean_chat_history(chat_history)
-    conversation.append({"role": "user", "content": input_text})
-
-    # Instead of tokenizer.apply_chat_template, we use our helper to generate a prompt.
-    prompt_text = apply_chat_template_for_text(conversation, add_generation_prompt=True)
-    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-    input_ids = input_ids.to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        "input_ids": input_ids,
-        "streamer": streamer,
-        "max_new_tokens": max_new_tokens,
-        "do_sample": True,
-        "top_p": top_p,
-        "top_k": top_k,
-        "temperature": temperature,
-        "num_beams": 1,
-        "repetition_penalty": repetition_penalty,
-    }
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-    outputs = []
-    # Collect the generated text from the streamer.
-    for new_text in streamer:
-        outputs.append(new_text)
-    final_response = "".join(outputs)
-    # Append assistant reply to conversation.
-    updated_history = conversation + [{"role": "assistant", "content": final_response}]
-    return final_response, updated_history
-
-# ---------------------------
-# 2. Qwen 2 VL OCR Tab
-# ---------------------------
-# Uses Qwen2VL OCR model for multimodal input (text + image)
-
-MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
-model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_QWEN,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
-
-@spaces.GPU
-def generate_qwen_ocr(input_text: str, image):
-    """
-    Uses the Qwen2VL OCR model to process an image along with text.
-    """
-    if image is None:
-        return "No image provided."
-    # Build message with system and user content.
-    messages = [
-        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-        {"role": "user", "content": [{"type": "text", "text": input_text}, {"type": "image", "image": image}]}
-    ]
-    # Use the processor's chat template.
-    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to("cuda")
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        **inputs,
-        "streamer": streamer,
-        "max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
-        "do_sample": True,
-        "temperature": 0.6,
-        "top_p": 0.9,
-        "top_k": 50,
-        "repetition_penalty": 1.2,
-    }
-    thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-    thread.start()
-    outputs = []
-    for new_text in streamer:
-        outputs.append(new_text.replace("<|im_end|>", ""))
-    final_response = "".join(outputs)
-    return final_response
-
-# ---------------------------
-# 3. Image Gen LoRA Tab
-# ---------------------------
-# Uses the SDXL pipeline with LoRA options.
-
-MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # set your SDXL model path via env variable
-MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
-USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
-ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
-BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
 
 sd_pipe = StableDiffusionXLPipeline.from_pretrained(
     MODEL_ID_SD,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -202,12 +51,26 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
 sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
 if torch.cuda.is_available():
     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
 if USE_TORCH_COMPILE:
     sd_pipe.compile()
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
 
-# LoRA options dictionary.
 LORA_OPTIONS = {
     "Realism (face/character)👦🏻": ("prithivMLmods/Canopus-Realism-LoRA", "Canopus-Realism-LoRA.safetensors", "rlms"),
     "Pixar (art/toons)🙀": ("prithivMLmods/Canopus-Pixar-Art", "Canopus-Pixar-Art.safetensors", "pixar"),
@@ -224,7 +87,6 @@ LORA_OPTIONS = {
     "Art Minimalistic (paint/semireal)🎨": ("prithivMLmods/Canopus-Art-Medium-LoRA", "Canopus-Art-Medium-LoRA.safetensors", "mdm"),
 }
 
-# Style options.
 style_list = [
     {
         "name": "3840 x 2160",
@@ -251,102 +113,198 @@ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
 DEFAULT_STYLE_NAME = "3840 x 2160"
 STYLE_NAMES = list(styles.keys())
 
-def apply_style(style_name: str, positive: str, negative: str = ""):
     if style_name in styles:
         p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
     else:
         p, n = styles[DEFAULT_STYLE_NAME]
-    return p.replace("{prompt}", positive), n + (negative if negative else "")
 
-@spaces.GPU
-def generate_image_lora(prompt: str, negative_prompt: str, use_negative_prompt: bool, seed: int, width: int, height: int, guidance_scale: float, randomize_seed: bool, style_name: str, lora_model: str):
     seed = int(randomize_seed_fn(seed, randomize_seed))
     positive_prompt, effective_negative_prompt = apply_style(style_name, prompt, negative_prompt)
     if not use_negative_prompt:
         effective_negative_prompt = ""
-    # Set the desired LoRA adapter.
     model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
-    sd_pipe.set_adapters(adapter_name)
-    # Generate image(s)
-    options = {
-        "prompt": [positive_prompt],
-        "negative_prompt": [effective_negative_prompt],
-        "width": width,
-        "height": height,
-        "guidance_scale": guidance_scale,
-        "num_inference_steps": 20,
-        "num_images_per_prompt": 1,
-        "cross_attention_kwargs": {"scale": 0.65},
-        "output_type": "pil",
-    }
-    outputs = sd_pipe(**options)
-    images = outputs.images
-    image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
 # ---------------------------
-# Build Gradio Interface with Three Tabs
 # ---------------------------
-with gr.Blocks(css=".gradio-container {max-width: 900px; margin: auto;}") as demo:
-    gr.Markdown("## Multi-Functional Demo: Chat Interface | Qwen 2 VL OCR | Image Gen LoRA")
-
-    with gr.Tabs():
-        # Tab 1: Chat Interface
-        with gr.Tab("Chat Interface"):
-            chat_output = gr.Chatbot(label="Chat Conversation")
-            with gr.Row():
-                chat_inp = gr.Textbox(label="Enter your message", placeholder="Type your message here...", lines=2)
-                send_btn = gr.Button("Send")
-            with gr.Row():
-                max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-                temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
-                top_p_slider = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-                top_k_slider = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                rep_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-            state = gr.State([])
-
-            def chat_step(user_message, history, max_tokens, temp, top_p, top_k, rep_penalty):
-                response, updated_history = chat_generate(user_message, history, max_tokens, temp, top_p, top_k, rep_penalty)
-                return updated_history, updated_history
-
-            send_btn.click(chat_step,
-                           inputs=[chat_inp, state, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider, rep_penalty_slider],
-                           outputs=[chat_output, state])
-            chat_inp.submit(chat_step,
-                            inputs=[chat_inp, state, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider, rep_penalty_slider],
-                            outputs=[chat_output, state])
-
-        # Tab 2: Qwen 2 VL OCR
-        with gr.Tab("Qwen 2 VL OCR"):
-            gr.Markdown("Upload an image and enter a prompt. The model will return OCR/extraction or descriptive text from the image.")
-            ocr_inp = gr.Textbox(label="Enter prompt", placeholder="Describe what you want to extract...", lines=2)
-            image_inp = gr.Image(label="Upload Image", type="pil")
-            ocr_output = gr.Textbox(label="Output", placeholder="Model output will appear here...", lines=5)
-            ocr_btn = gr.Button("Run Qwen 2 VL OCR")
-            ocr_btn.click(generate_qwen_ocr, inputs=[ocr_inp, image_inp], outputs=ocr_output)
-
-        # Tab 3: Image Gen LoRA
-        with gr.Tab("Image Gen LoRA"):
-            gr.Markdown("Generate images with SDXL using various LoRA models and quality styles.")
-            with gr.Row():
-                prompt_img = gr.Textbox(label="Prompt", placeholder="Enter prompt for image generation...", lines=2)
-                negative_prompt_img = gr.Textbox(label="Negative Prompt", placeholder="(optional) negative prompt", lines=2)
-                use_neg_checkbox = gr.Checkbox(label="Use Negative Prompt", value=True)
-            with gr.Row():
-                seed_slider = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, value=0)
-                randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
-            with gr.Row():
-                width_slider = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
-                height_slider = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
-            guidance_slider = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, step=0.1, value=3.0)
-            style_radio = gr.Radio(label="Quality Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
-            lora_dropdown = gr.Dropdown(label="LoRA Selection", choices=list(LORA_OPTIONS.keys()), value="Realism (face/character)👦🏻")
-            img_output = gr.Gallery(label="Generated Images", columns=1, preview=True)
-            seed_output = gr.Number(label="Used Seed")
-            run_img_btn = gr.Button("Generate Image")
-            run_img_btn.click(generate_image_lora,
-                              inputs=[prompt_img, negative_prompt_img, use_neg_checkbox, seed_slider, width_slider, height_slider, guidance_slider, randomize_seed_checkbox, style_radio, lora_dropdown],
-                              outputs=[img_output, seed_output])
-
 if __name__ == "__main__":
     demo.queue(max_size=20).launch(share=True)
 
 import time
 import asyncio
 from threading import Thread
+from typing import Tuple
 
 import gradio as gr
 import spaces
 from PIL import Image
 import cv2
 
+from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     AutoProcessor,
 )
 from transformers.image_utils import load_image
 
 # ---------------------------
+# Global Settings and Devices
 # ---------------------------
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+MAX_SEED = np.iinfo(np.int32).max
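+# MAX_SEED caps the seed slider in the UI; randomize_seed_fn() below samples from [0, MAX_SEED].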
 
 # ---------------------------
+# IMAGE GEN LORA TAB: SDXL Gen with LoRA Options
 # ---------------------------
 
+# Resolve the SDXL checkpoint (env override with a default fallback)
+MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # Path from env variable
+if MODEL_ID_SD is None:
+    MODEL_ID_SD = "SG161222/RealVisXL_V4.0_Lightning"  # default fallback
 
+# Load SDXL pipeline (use GPU if available)
 sd_pipe = StableDiffusionXLPipeline.from_pretrained(
     MODEL_ID_SD,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
 if torch.cuda.is_available():
     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
+
+# Optional: compile or offload if desired
+USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 if USE_TORCH_COMPILE:
     sd_pipe.compile()
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
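+# Note: StableDiffusionXLPipeline exposes no .compile() method; if USE_TORCH_COMPILE is set,
+# the usual pattern is sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True).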
 
+def save_image(img: Image.Image) -> str:
+    unique_name = str(uuid.uuid4()) + ".png"
+    img.save(unique_name)
+    return unique_name
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+# LoRA options and style definitions
 LORA_OPTIONS = {
     "Realism (face/character)👦🏻": ("prithivMLmods/Canopus-Realism-LoRA", "Canopus-Realism-LoRA.safetensors", "rlms"),
     "Pixar (art/toons)🙀": ("prithivMLmods/Canopus-Pixar-Art", "Canopus-Pixar-Art.safetensors", "pixar"),
     "Art Minimalistic (paint/semireal)🎨": ("prithivMLmods/Canopus-Art-Medium-LoRA", "Canopus-Art-Medium-LoRA.safetensors", "mdm"),
 }
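+# Each LORA_OPTIONS value is (Hub repo id, weights filename, adapter name); generate_image_lora() unpacks it in that order.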
 
 style_list = [
     {
         "name": "3840 x 2160",
 
 DEFAULT_STYLE_NAME = "3840 x 2160"
 STYLE_NAMES = list(styles.keys())
 
+def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
     if style_name in styles:
         p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
     else:
         p, n = styles[DEFAULT_STYLE_NAME]
+    return p.replace("{prompt}", positive), n + negative
 
+@spaces.GPU(duration=180, enable_queue=True)
+def generate_image_lora(
+    prompt: str,
+    negative_prompt: str = "",
+    use_negative_prompt: bool = True,
+    seed: int = 0,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3,
+    randomize_seed: bool = False,
+    style_name: str = DEFAULT_STYLE_NAME,
+    lora_model: str = "Realism (face/character)👦🏻",
+    progress=gr.Progress(track_tqdm=True),
+):
     seed = int(randomize_seed_fn(seed, randomize_seed))
     positive_prompt, effective_negative_prompt = apply_style(style_name, prompt, negative_prompt)
     if not use_negative_prompt:
         effective_negative_prompt = ""
+    # Set LoRA adapter based on selection
     model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
+    sd_pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)
+    sd_pipe.to(device)
+
+    outputs = sd_pipe(
+        prompt=positive_prompt,
+        negative_prompt=effective_negative_prompt,
+        width=width,
+        height=height,
+        guidance_scale=guidance_scale,
+        num_inference_steps=20,
+        num_images_per_prompt=1,
+        cross_attention_kwargs={"scale": 0.65},
+        output_type="pil",
+    )
+    image_paths = [save_image(img) for img in outputs.images]
     return image_paths, seed
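+# Quick smoke test outside the UI (hypothetical prompt), e.g.:
+#   paths, used_seed = generate_image_lora("portrait photo, golden hour", randomize_seed=True)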
 
 # ---------------------------
+# Qwen 2 VL OCR TAB
 # ---------------------------
+MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
+model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_QWEN,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to("cuda" if torch.cuda.is_available() else "cpu").eval()
+
+@spaces.GPU
+def qwen2vl_ocr_generate(
+    prompt: str,
+    file: list,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+):
+    # This tab assumes the user supplies one or more images for OCR.
+    images = []
+    if file:
+        # Load the image(s) using the transformers helper.
+        for f in file:
+            images.append(load_image(f))
+    # Build the chat message: all images first, then the text prompt.
+    messages = [{
+        "role": "user",
+        "content": [
+            *[{"type": "image", "image": image} for image in images],
+            {"type": "text", "text": prompt},
+        ]
+    }]
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(text=[prompt_full], images=images if images else None, return_tensors="pt", padding=True).to("cuda" if torch.cuda.is_available() else "cpu")
+    # Use non-streaming generation for simplicity
+    output_ids = model_m.generate(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+    )
+    # Decode only the newly generated tokens, not the echoed prompt.
+    new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
+    final_response = processor.tokenizer.decode(new_tokens[0], skip_special_tokens=True)
+    return final_response
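+# load_image() accepts local paths or URLs, so the function also works outside the UI, e.g.:
+#   qwen2vl_ocr_generate("Extract all text from the image.", ["sample.png"])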
 
+# ---------------------------
+# CHAT INTERFACE TAB (Text-only)
+# ---------------------------
+# Load text-only model and tokenizer
+model_id_text = "prithivMLmods/FastThink-0.5B-Tiny"
+tokenizer = AutoTokenizer.from_pretrained(model_id_text)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id_text,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+model.eval()
 
+def chat_generate(prompt: str, max_new_tokens: int = 1024, temperature: float = 0.6,
+                  top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
+    # For simplicity, use basic generation without streaming.
+    input_ids = tokenizer.encode(prompt, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+    input_ids = input_ids.to(model.device)
+    output_ids = model.generate(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+    )
+    # Return only the newly generated tokens, not the echoed prompt.
+    response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
+    return response
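+# Note: the prompt is passed to the tokenizer as plain text; for chat-tuned checkpoints one
+# would usually wrap it first, e.g.:
+#   prompt = tokenizer.apply_chat_template([{"role": "user", "content": user_text}],
+#                                          tokenize=False, add_generation_prompt=True)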
 
+# ---------------------------
+# GRADIO INTERFACE WITH TABS
+# ---------------------------
+with gr.Blocks(title="Multi-Modal Playground") as demo:
+    gr.Markdown("# Multi-Modal Playground")
+
+    with gr.Tab("Image Gen LoRA"):
+        gr.Markdown("## Generate Images using SDXL + LoRA")
+        with gr.Row():
+            prompt_img = gr.Textbox(label="Prompt", placeholder="Enter your image prompt here")
+            negative_prompt_img = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt (optional)", lines=2)
+        with gr.Row():
+            use_negative = gr.Checkbox(label="Use Negative Prompt", value=True)
+            seed_img = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+        with gr.Row():
+            width_img = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
+            height_img = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
+        with gr.Row():
+            guidance_scale_img = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, step=0.1, value=3.0)
+        with gr.Row():
+            style_selection = gr.Radio(choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, label="Quality Style")
+            lora_selection = gr.Dropdown(choices=list(LORA_OPTIONS.keys()), value="Realism (face/character)👦🏻", label="LoRA Selection")
+        run_img = gr.Button("Generate Image")
+        gallery = gr.Gallery(label="Generated Images", columns=1)  # .style() was removed in Gradio 4; layout kwargs go on the component
+        output_seed = gr.Number(label="Seed Used")
+        run_img.click(
+            generate_image_lora,
+            inputs=[prompt_img, negative_prompt_img, use_negative, seed_img, width_img, height_img, guidance_scale_img,
+                    randomize_seed, style_selection, lora_selection],
+            outputs=[gallery, output_seed]
+        )
+
+    with gr.Tab("Qwen 2 VL OCR"):
+        gr.Markdown("## Extract and Generate Text from Images (OCR)")
+        with gr.Row():
+            prompt_ocr = gr.Textbox(label="OCR Prompt", placeholder="Enter instructions for OCR/text extraction")
+            file_ocr = gr.File(label="Upload Image", file_types=["image"], file_count="multiple")
+        run_ocr = gr.Button("Run OCR")
+        output_ocr = gr.Textbox(label="OCR Output")
+        run_ocr.click(
+            qwen2vl_ocr_generate,
+            inputs=[prompt_ocr, file_ocr],
+            outputs=output_ocr
+        )
+
+    with gr.Tab("Chat Interface"):
+        gr.Markdown("## Chat with the Text-Only Model")
+        chat_input = gr.Textbox(label="Enter your message", placeholder="Say something...")
+        max_tokens_chat = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+        temperature_chat = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
+        top_p_chat = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+        top_k_chat = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
+        rep_penalty_chat = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
+        run_chat = gr.Button("Send")
+        chat_output = gr.Textbox(label="Response")
+        run_chat.click(
+            chat_generate,
+            inputs=[chat_input, max_tokens_chat, temperature_chat, top_p_chat, top_k_chat, rep_penalty_chat],
+            outputs=chat_output
+        )
+
+    gr.Markdown("**Adjust parameters in each tab as needed.**")
 
 if __name__ == "__main__":
     demo.queue(max_size=20).launch(share=True)