prithivMLmods committed on
Commit f5e2b63 · verified · 1 Parent(s): 8419dc4

Update app.py

Files changed (1)
  1. app.py +428 -736
app.py CHANGED
@@ -1,756 +1,448 @@
- import os
- import random
- import uuid
- import json
- import time
- import asyncio
- import tempfile
- from threading import Thread
- import base64
- import shutil
- import re
-
  import gradio as gr
  import spaces
  import torch
  import numpy as np
- from PIL import Image
- import edge_tts
- import trimesh
- import soundfile as sf  # New import for audio file reading

- import supervision as sv
- from ultralytics import YOLO as YOLODetector
- from huggingface_hub import hf_hub_download
-
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer,
-     Qwen2VLForConditionalGeneration,
-     AutoProcessor,
  )
- from transformers.image_utils import load_image
-
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
- from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
- from diffusers.utils import export_to_ply
-
- os.system('pip install backoff')
- # Global constants and helper functions
-
- MAX_SEED = np.iinfo(np.int32).max
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
- def glb_to_data_url(glb_path: str) -> str:
-     """
-     Reads a GLB file from disk and returns a data URL with a base64 encoded representation.
-     (Not used in this method.)
-     """
-     with open(glb_path, "rb") as f:
-         data = f.read()
-     b64_data = base64.b64encode(data).decode("utf-8")
-     return f"data:model/gltf-binary;base64,{b64_data}"
-
- def progress_bar_html(label: str) -> str:
-     """
-     Returns an HTML snippet for a thin progress bar with a label.
-     The progress bar is styled as a dark red animated bar.
-     """
-     return f'''
- <div style="display: flex; align-items: center;">
-     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-     <div style="width: 110px; height: 5px; background-color: #AFEEEE; border-radius: 2px; overflow: hidden;">
-         <div style="width: 100%; height: 100%; background-color: #00FFFF; animation: loading 1.5s linear infinite;"></div>
-     </div>
- </div>
- <style>
- @keyframes loading {{
-     0% {{ transform: translateX(-100%); }}
-     100% {{ transform: translateX(100%); }}
- }}
- </style>
-     '''
-
- # Model class for Text-to-3D Generation (ShapE)
-
- class Model:
-     def __init__(self):
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
-         self.pipe.to(self.device)
-         # Ensure the text encoder is in half precision to avoid dtype mismatches.
-         if torch.cuda.is_available():
-             try:
-                 self.pipe.text_encoder = self.pipe.text_encoder.half()
-             except AttributeError:
-                 pass
-
-         self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
-         self.pipe_img.to(self.device)
-         # Use getattr with a default value to avoid AttributeError if text_encoder is missing.
-         if torch.cuda.is_available():
-             text_encoder_img = getattr(self.pipe_img, "text_encoder", None)
-             if text_encoder_img is not None:
-                 self.pipe_img.text_encoder = text_encoder_img.half()
-
-     def to_glb(self, ply_path: str) -> str:
-         mesh = trimesh.load(ply_path)
-         # Rotate the mesh for proper orientation
-         rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
-         mesh.apply_transform(rot)
-         rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
-         mesh.apply_transform(rot)
-         mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
-         mesh.export(mesh_path.name, file_type="glb")
-         return mesh_path.name
-
-     def run_text(self, prompt: str, seed: int = 0, guidance_scale: float = 15.0, num_steps: int = 64) -> str:
-         generator = torch.Generator(device=self.device).manual_seed(seed)
-         images = self.pipe(
-             prompt,
-             generator=generator,
-             guidance_scale=guidance_scale,
-             num_inference_steps=num_steps,
-             output_type="mesh",
-         ).images
-         ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
-         export_to_ply(images[0], ply_path.name)
-         return self.to_glb(ply_path.name)
-
-     def run_image(self, image: Image.Image, seed: int = 0, guidance_scale: float = 3.0, num_steps: int = 64) -> str:
-         generator = torch.Generator(device=self.device).manual_seed(seed)
-         images = self.pipe_img(
-             image,
-             generator=generator,
-             guidance_scale=guidance_scale,
-             num_inference_steps=num_steps,
-             output_type="mesh",
-         ).images
-         ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
-         export_to_ply(images[0], ply_path.name)
-         return self.to_glb(ply_path.name)
-
- # New Tools for Web Functionality using DuckDuckGo and smolagents
-
- from typing import Any, Optional
- from smolagents.tools import Tool
- import duckduckgo_search
-
- class DuckDuckGoSearchTool(Tool):
-     name = "web_search"
-     description = "Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."
-     inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
-     output_type = "string"
-
-     def __init__(self, max_results=10, **kwargs):
-         super().__init__()
-         self.max_results = max_results
-         try:
-             from duckduckgo_search import DDGS
-         except ImportError as e:
-             raise ImportError(
-                 "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
-             ) from e
-         self.ddgs = DDGS(**kwargs)
-
-     def forward(self, query: str) -> str:
-         results = self.ddgs.text(query, max_results=self.max_results)
-         if len(results) == 0:
-             raise Exception("No results found! Try a less restrictive/shorter query.")
-         postprocessed_results = [
-             f"[{result['title']}]({result['href']})\n{result['body']}" for result in results
-         ]
-         return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
-
- class VisitWebpageTool(Tool):
-     name = "visit_webpage"
-     description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
-     inputs = {'url': {'type': 'string', 'description': 'The url of the webpage to visit.'}}
-     output_type = "string"
-
-     def __init__(self, *args, **kwargs):
-         self.is_initialized = False
-
-     def forward(self, url: str) -> str:
-         try:
-             import requests
-             from markdownify import markdownify
-             from requests.exceptions import RequestException
-
-             from smolagents.utils import truncate_content
-         except ImportError as e:
-             raise ImportError(
-                 "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
-             ) from e
-         try:
-             # Send a GET request to the URL with a 20-second timeout
-             response = requests.get(url, timeout=20)
-             response.raise_for_status()  # Raise an exception for bad status codes
-
-             # Convert the HTML content to Markdown
-             markdown_content = markdownify(response.text).strip()
-
-             # Remove multiple line breaks
-             markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
-
-             return truncate_content(markdown_content, 10000)
-
-         except requests.exceptions.Timeout:
-             return "The request timed out. Please try again later or check the URL."
-         except RequestException as e:
-             return f"Error fetching the webpage: {str(e)}"
-         except Exception as e:
-             return f"An unexpected error occurred: {str(e)}"
-
- # rAgent Reasoning using Llama mode OpenAI
-
- from openai import OpenAI
-
- ACCESS_TOKEN = os.getenv("HF_TOKEN")
- ragent_client = OpenAI(
-     base_url="https://api-inference.huggingface.co/v1/",
-     api_key=ACCESS_TOKEN,
  )
-
- SYSTEM_PROMPT = """
-
- "You are an expert assistant who solves tasks using Python code. Follow these steps:\n"
- "1. **Thought**: Explain your reasoning and plan for solving the task.\n"
- "2. **Code**: Write Python code to implement your solution.\n"
- "3. **Observation**: Analyze the output of the code and summarize the results.\n"
- "4. **Final Answer**: Provide a concise conclusion or final result.\n\n"
- f"Task: {{task}}"
-
  """

- def ragent_reasoning(prompt: str, history: list[dict], max_tokens: int = 2048, temperature: float = 0.7, top_p: float = 0.95):
-     """
-     Uses the Llama mode OpenAI model to perform a structured reasoning chain.
-     """
-     messages = [{"role": "system", "content": SYSTEM_PROMPT}]
-     # Incorporate conversation history (if any)
-     for msg in history:
-         if msg.get("role") == "user":
-             messages.append({"role": "user", "content": msg["content"]})
-         elif msg.get("role") == "assistant":
-             messages.append({"role": "assistant", "content": msg["content"]})
-     messages.append({"role": "user", "content": prompt})
-     response = ""
-     stream = ragent_client.chat.completions.create(
-         model="meta-llama/Meta-Llama-3.1-8B-Instruct",
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-         messages=messages,
-     )
-     for message in stream:
-         token = message.choices[0].delta.content
-         response += token
-         yield response
-
- # ------------------------------------------------------------------------------
- # New Phi-4 Multimodal Feature (Image & Audio)
- # ------------------------------------------------------------------------------
- # Define prompt structure for Phi-4
- phi4_user_prompt = '<|user|>'
- phi4_assistant_prompt = '<|assistant|>'
- phi4_prompt_suffix = '<|end|>'
-
- # Load Phi-4 multimodal model and processor using unique variable names
- phi4_model_path = "microsoft/Phi-4-multimodal-instruct"
- phi4_processor = AutoProcessor.from_pretrained(phi4_model_path, trust_remote_code=True)
- phi4_model = AutoModelForCausalLM.from_pretrained(
-     phi4_model_path,
-     device_map="auto",
-     torch_dtype="auto",
-     trust_remote_code=True,
-     _attn_implementation="eager",
- )
-
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- # Load Models and Pipelines for Chat, Image, and Multimodal Processing
- # Load the text-only model and tokenizer (for pure text chat)
-
- model_id = "prithivMLmods/Ganymede-Llama-3.3-3B-Preview"  # prithivMLmods/FastThink-0.5B-Tiny
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
- )
- model.eval()
-
- # Voices for text-to-speech
- TTS_VOICES = [
-     "en-US-JennyNeural",  # @tts1
-     "en-US-GuyNeural",  # @tts2
- ]
-
- # Load multimodal processor and model (e.g. for OCR and image processing)
- MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
- model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()
-
- # Asynchronous text-to-speech
-
- async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-     """Convert text to speech using Edge TTS and save as MP3"""
-     communicate = edge_tts.Communicate(text, voice)
-     await communicate.save(output_file)
-     return output_file
-
- # Utility function to clean conversation history
-
- def clean_chat_history(chat_history):
-     """
-     Filter out any chat entries whose "content" is not a string.
-     This helps prevent errors when concatenating previous messages.
-     """
-     cleaned = []
-     for msg in chat_history:
-         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-             cleaned.append(msg)
-     return cleaned
-
- # Stable Diffusion XL Pipeline for Image Generation
- # Model In Use : SG161222/RealVisXL_V5.0_Lightning
-
- MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
- BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
-
- sd_pipe = StableDiffusionXLPipeline.from_pretrained(
-     MODEL_ID_SD,
-     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-     use_safetensors=True,
-     add_watermarker=False,
- ).to(device)
- sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
-
- if torch.cuda.is_available():
-     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
-
- if USE_TORCH_COMPILE:
-     sd_pipe.compile()
-
- if ENABLE_CPU_OFFLOAD:
-     sd_pipe.enable_model_cpu_offload()
-
- def save_image(img: Image.Image) -> str:
-     """Save a PIL image with a unique filename and return the path."""
-     unique_name = str(uuid.uuid4()) + ".png"
-     img.save(unique_name)
-     return unique_name
-
- @spaces.GPU(duration=60, enable_queue=True)
- # SG161222/RealVisXL_V5.0_Lightning
- def generate_image_fn(
-     prompt: str,
-     negative_prompt: str = "",
-     use_negative_prompt: bool = False,
-     seed: int = 1,
-     width: int = 1024,
-     height: int = 1024,
-     guidance_scale: float = 3,
-     num_inference_steps: int = 25,
-     randomize_seed: bool = False,
-     use_resolution_binning: bool = True,
-     num_images: int = 1,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     """Generate images using the SDXL pipeline."""
-     seed = int(randomize_seed_fn(seed, randomize_seed))
-     generator = torch.Generator(device=device).manual_seed(seed)
-
-     options = {
-         "prompt": [prompt] * num_images,
-         "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
-         "width": width,
-         "height": height,
-         "guidance_scale": guidance_scale,
-         "num_inference_steps": num_inference_steps,
-         "generator": generator,
-         "output_type": "pil",
-     }
-     if use_resolution_binning:
-         options["use_resolution_binning"] = True
-
-     images = []
-     # Process in batches
-     for i in range(0, num_images, BATCH_SIZE):
-         batch_options = options.copy()
-         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
-         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
-             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-         if device.type == "cuda":
-             with torch.autocast("cuda", dtype=torch.float16):
-                 outputs = sd_pipe(**batch_options)
-         else:
-             outputs = sd_pipe(**batch_options)
-         images.extend(outputs.images)
-     image_paths = [save_image(img) for img in images]
-     return image_paths, seed
-
- # Text-to-3D Generation using the ShapE Pipeline
-
- @spaces.GPU(duration=120, enable_queue=True)
- def generate_3d_fn(
-     prompt: str,
-     seed: int = 1,
-     guidance_scale: float = 15.0,
-     num_steps: int = 64,
-     randomize_seed: bool = False,
- ):
-     """
-     Generate a 3D model from text using the ShapE pipeline.
-     Returns a tuple of (glb_file_path, used_seed).
-     """
-     seed = int(randomize_seed_fn(seed, randomize_seed))
-     model3d = Model()
-     glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
-     return glb_path, seed
-
- # YOLO Object Detection Setup
- YOLO_MODEL_REPO = "strangerzonehf/Flux-Ultimate-LoRA-Collection"
- YOLO_CHECKPOINT_NAME = "images/demo.pt"
- yolo_model_path = hf_hub_download(repo_id=YOLO_MODEL_REPO, filename=YOLO_CHECKPOINT_NAME)
- yolo_detector = YOLODetector(yolo_model_path)
-
- def detect_objects(image: np.ndarray):
-     """Runs object detection on the input image."""
-     results = yolo_detector(image, verbose=False)[0]
-     detections = sv.Detections.from_ultralytics(results).with_nms()
-
-     box_annotator = sv.BoxAnnotator()
-     label_annotator = sv.LabelAnnotator()
-
-     annotated_image = image.copy()
-     annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
-     annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
-
-     return Image.fromarray(annotated_image)
-
- # Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, and now @phi4 commands
-
- @spaces.GPU
- def generate(
-     input_dict: dict,
-     chat_history: list[dict],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ):
-     """
-     Generates chatbot responses with support for multimodal input and special commands:
-     - "@tts1" or "@tts2": triggers text-to-speech.
-     - "@image": triggers image generation using the SDXL pipeline.
-     - "@3d": triggers 3D model generation using the ShapE pipeline.
-     - "@web": triggers a web search or webpage visit.
-     - "@rAgent": initiates a reasoning chain using Llama mode.
-     - "@yolo": triggers object detection using YOLO.
-     - **"@phi4": triggers multimodal (image/audio) processing using the Phi-4 model.**
-     """
-     text = input_dict["text"]
-     files = input_dict.get("files", [])
-
-     # --- 3D Generation branch ---
-     if text.strip().lower().startswith("@3d"):
-         prompt = text[len("@3d"):].strip()
-         yield progress_bar_html("Processing 3D Mesh Generation")
-         glb_path, used_seed = generate_3d_fn(
-             prompt=prompt,
-             seed=1,
-             guidance_scale=15.0,
-             num_steps=64,
-             randomize_seed=True,
-         )
-         # Copy the GLB file to a static folder.
-         yield progress_bar_html("Finalizing 3D Mesh Generation")
-         static_folder = os.path.join(os.getcwd(), "static")
-         if not os.path.exists(static_folder):
-             os.makedirs(static_folder)
-         new_filename = f"mesh_{uuid.uuid4()}.glb"
-         new_filepath = os.path.join(static_folder, new_filename)
-         shutil.copy(glb_path, new_filepath)
-
-         yield gr.File(new_filepath)
-         return
-
-     # --- Image Generation branch ---
-     if text.strip().lower().startswith("@image"):
-         prompt = text[len("@image"):].strip()
-         yield progress_bar_html("Generating Image")
-         image_paths, used_seed = generate_image_fn(
-             prompt=prompt,
-             negative_prompt="",
-             use_negative_prompt=False,
-             seed=1,
-             width=1024,
-             height=1024,
-             guidance_scale=3,
-             num_inference_steps=25,
-             randomize_seed=True,
-             use_resolution_binning=True,
-             num_images=1,
-         )
-         yield gr.Image(image_paths[0])
-         return
-
-     # --- Web Search/Visit branch ---
-     if text.strip().lower().startswith("@web"):
-         web_command = text[len("@web"):].strip()
-         # If the command starts with "visit", then treat the rest as a URL
-         if web_command.lower().startswith("visit"):
-             url = web_command[len("visit"):].strip()
-             yield progress_bar_html("Visiting Webpage")
-             visitor = VisitWebpageTool()
-             content = visitor.forward(url)
-             yield content
-         else:
-             # Otherwise, treat the rest as a search query.
-             query = web_command
-             yield progress_bar_html("Performing Web Search")
-             searcher = DuckDuckGoSearchTool()
-             results = searcher.forward(query)
-             yield results
-         return
-
-     # --- rAgent Reasoning branch ---
-     if text.strip().lower().startswith("@ragent"):
-         prompt = text[len("@ragent"):].strip()
-         yield progress_bar_html("Processing Reasoning Chain")
-         # Pass the current chat history (cleaned) to help inform the chain.
-         for partial in ragent_reasoning(prompt, clean_chat_history(chat_history)):
-             yield partial
-         return
-
-     # --- YOLO Object Detection branch ---
-     if text.strip().lower().startswith("@yolo"):
-         yield progress_bar_html("Performing Object Detection")
-         if not files or len(files) == 0:
-             yield "Error: Please attach an image for YOLO object detection."
-             return
-         # Use the first attached image
-         input_file = files[0]
-         try:
-             if isinstance(input_file, str):
-                 pil_image = Image.open(input_file)
-             else:
-                 pil_image = input_file
-         except Exception as e:
-             yield f"Error loading image: {str(e)}"
-             return
-         np_image = np.array(pil_image)
-         result_img = detect_objects(np_image)
-         yield gr.Image(result_img)
-         return
-
-     # --- Phi-4 Multimodal branch (Image/Audio) with Streaming ---
-     if text.strip().lower().startswith("@phi4"):
-         question = text[len("@phi4"):].strip()
-         if not files:
-             yield "Error: Please attach an image or audio file for @phi4 multimodal processing."
-             return
-         if not question:
-             yield "Error: Please provide a question after @phi4."
-             return
-         # Determine input type (Image or Audio) from the first file
-         input_file = files[0]
-         try:
-             # If file is already a PIL Image, treat as image
-             if isinstance(input_file, Image.Image):
-                 input_type = "Image"
-                 file_for_phi4 = input_file
-             else:
-                 # Try opening as image; if it fails, assume audio
-                 try:
-                     file_for_phi4 = Image.open(input_file)
-                     input_type = "Image"
-                 except Exception:
-                     input_type = "Audio"
-                     file_for_phi4 = input_file
-         except Exception:
-             input_type = "Audio"
-             file_for_phi4 = input_file
-
-         if input_type == "Image":
-             phi4_prompt = f'{phi4_user_prompt}<|image_1|>{question}{phi4_prompt_suffix}{phi4_assistant_prompt}'
-             inputs = phi4_processor(text=phi4_prompt, images=file_for_phi4, return_tensors='pt').to(phi4_model.device)
-         elif input_type == "Audio":
-             phi4_prompt = f'{phi4_user_prompt}<|audio_1|>{question}{phi4_prompt_suffix}{phi4_assistant_prompt}'
-             audio, samplerate = sf.read(file_for_phi4)
-             inputs = phi4_processor(text=phi4_prompt, audios=[(audio, samplerate)], return_tensors='pt').to(phi4_model.device)
-         else:
-             yield "Invalid file type for @phi4 multimodal processing."
-             return
-
-         # Initialize the streamer
-         streamer = TextIteratorStreamer(phi4_processor, skip_prompt=True, skip_special_tokens=True)
-
-         # Prepare generation kwargs
-         generation_kwargs = {
-             **inputs,
-             "streamer": streamer,
-             "max_new_tokens": 200,
-             "num_logits_to_keep": 0,
-         }
-
-         # Start generation in a separate thread
-         thread = Thread(target=phi4_model.generate, kwargs=generation_kwargs)
-         thread.start()
-
-         # Stream the response
-         buffer = ""
-         yield progress_bar_html("Processing Phi-4 Multimodal")
-         for new_text in streamer:
-             buffer += new_text
-             time.sleep(0.01)  # Small delay to simulate real-time streaming
-             yield buffer
-         return
-
-     # --- Text and TTS branch ---
-     tts_prefix = "@tts"
-     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-
-     if is_tts and voice_index:
-         voice = TTS_VOICES[voice_index - 1]
-         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-         conversation = [{"role": "user", "content": text}]
-     else:
-         voice = None
-         text = text.replace(tts_prefix, "").strip()
-         conversation = clean_chat_history(chat_history)
-         conversation.append({"role": "user", "content": text})
-
-     if files:
-         if len(files) > 1:
-             images = [load_image(image) for image in files]
-         elif len(files) == 1:
-             images = [load_image(files[0])]
-         else:
-             images = []
-         messages = [{
-             "role": "user",
-             "content": [
-                 *[{"type": "image", "image": image} for image in images],
-                 {"type": "text", "text": text},
-             ]
-         }]
-         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-         thread.start()
-
-         buffer = ""
-         yield progress_bar_html("Processing with Qwen2VL OCR")
-         for new_text in streamer:
-             buffer += new_text
-             buffer = buffer.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer
-     else:
-         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-         input_ids = input_ids.to(model.device)
-         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             "input_ids": input_ids,
-             "streamer": streamer,
-             "max_new_tokens": max_new_tokens,
-             "do_sample": True,
-             "top_p": top_p,
-             "top_k": top_k,
-             "temperature": temperature,
-             "num_beams": 1,
-             "repetition_penalty": repetition_penalty,
-         }
-         t = Thread(target=model.generate, kwargs=generation_kwargs)
-         t.start()
-
-         outputs = []
-         yield progress_bar_html("Processing Chat Response")
-         for new_text in streamer:
-             outputs.append(new_text)
-             yield "".join(outputs)
-
-         final_response = "".join(outputs)
-         yield final_response
-
-         if is_tts and voice:
-             output_file = asyncio.run(text_to_speech(final_response, voice))
-             yield gr.Audio(output_file, autoplay=True)
-
- # Gradio Chat Interface Setup and Launch
-
- demo = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
-     ],
-     examples=[
-         [{"text": "@phi4 Transcribe the audio to text.", "files": ["examples/harvard.wav"]}],
-         [{"text": "@phi4 Summarize the content", "files": ["examples/write.jpg"]}],
-         [{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
-         [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
-         ["@image Chocolate dripping from a donut"],
-         ["@3d A birthday cupcake with cherry"],
-         ["@image A drawing of an man made out of hamburger, blue sky background, soft pastel colors"],
-         ["@tts2 What causes rainbows to form?"],
-         [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
-         [{"text": "@yolo", "files": ["examples/yolo.jpeg"]}],
-         ["@rAgent Explain how a binary search algorithm works."],
-         ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
-         ["@tts1 Explain Tower of Hanoi"],
-         ["Python Program for Array Rotation"],
-     ],
-     cache_examples=False,
-     type="messages",
-     description="# **Agent Dino `@phi4 'prompt..', @image, etc..`**",
-     fill_height=True,
-     textbox=gr.MultimodalTextbox(
-         label="Query Input",
-         file_types=["image", "audio"],
-         file_count="multiple",
-         placeholder="‎ @tts1, @tts2, @image, @3d, @phi4 [image, audio], @rAgent, @web, @yolo, default [plain text]"
-     ),
-     stop_btn="Stop Generation",
-     multimodal=True,
- )
-
- # Ensure the static folder exists
- if not os.path.exists("static"):
-     os.makedirs("static")
-
- from fastapi.staticfiles import StaticFiles
- demo.app.mount("/static", StaticFiles(directory="static"), name="static")
-
- if __name__ == "__main__":
-     demo.queue(max_size=20).launch(share=True)

  import gradio as gr
  import spaces
  import torch
+ from diffusers import AutoencoderKL, TCDScheduler
+ from diffusers.models.model_loading_utils import load_state_dict
+ from gradio_imageslider import ImageSlider
+ from huggingface_hub import hf_hub_download
+
+ from controlnet_union import ControlNetModel_Union
+ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
+
+ from PIL import Image, ImageDraw
  import numpy as np

+ config_file = hf_hub_download(
+     "xinsir/controlnet-union-sdxl-1.0",
+     filename="config_promax.json",
+ )
+
+ config = ControlNetModel_Union.load_config(config_file)
+ controlnet_model = ControlNetModel_Union.from_config(config)
+ model_file = hf_hub_download(
+     "xinsir/controlnet-union-sdxl-1.0",
+     filename="diffusion_pytorch_model_promax.safetensors",
  )
+ state_dict = load_state_dict(model_file)
+ model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
+     controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
  )
+ model.to(device="cuda", dtype=torch.float16)
+
+ vae = AutoencoderKL.from_pretrained(
+     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+ ).to("cuda")
+
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
+     "SG161222/RealVisXL_V5.0_Lightning",
+     torch_dtype=torch.float16,
+     vae=vae,
+     controlnet=model,
+     variant="fp16",
+ ).to("cuda")
+
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+
+
+ def can_expand(source_width, source_height, target_width, target_height, alignment):
+     """Checks if the image can be expanded based on the alignment."""
+     if alignment in ("Left", "Right") and source_width >= target_width:
+         return False
+     if alignment in ("Top", "Bottom") and source_height >= target_height:
+         return False
+     return True
+
+ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
+     target_size = (width, height)
+
+     # Calculate the scaling factor to fit the image within the target size
+     scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
+     new_width = int(image.width * scale_factor)
+     new_height = int(image.height * scale_factor)
+
+     # Resize the source image to fit within target size
+     source = image.resize((new_width, new_height), Image.LANCZOS)
+
+     # Apply resize option using percentages
+     if resize_option == "Full":
+         resize_percentage = 100
+     elif resize_option == "50%":
+         resize_percentage = 50
+     elif resize_option == "33%":
+         resize_percentage = 33
+     elif resize_option == "25%":
+         resize_percentage = 25
+     else:  # Custom
+         resize_percentage = custom_resize_percentage
+
+     # Calculate new dimensions based on percentage
+     resize_factor = resize_percentage / 100
+     new_width = int(source.width * resize_factor)
+     new_height = int(source.height * resize_factor)
+
+     # Ensure minimum size of 64 pixels
+     new_width = max(new_width, 64)
+     new_height = max(new_height, 64)
+
+     # Resize the image
+     source = source.resize((new_width, new_height), Image.LANCZOS)
+
+     # Calculate the overlap in pixels based on the percentage
+     overlap_x = int(new_width * (overlap_percentage / 100))
+     overlap_y = int(new_height * (overlap_percentage / 100))
+
+     # Ensure minimum overlap of 1 pixel
+     overlap_x = max(overlap_x, 1)
+     overlap_y = max(overlap_y, 1)
+
+     # Calculate margins based on alignment
+     if alignment == "Middle":
+         margin_x = (target_size[0] - new_width) // 2
+         margin_y = (target_size[1] - new_height) // 2
+     elif alignment == "Left":
+         margin_x = 0
+         margin_y = (target_size[1] - new_height) // 2
+     elif alignment == "Right":
+         margin_x = target_size[0] - new_width
+         margin_y = (target_size[1] - new_height) // 2
+     elif alignment == "Top":
+         margin_x = (target_size[0] - new_width) // 2
+         margin_y = 0
+     elif alignment == "Bottom":
+         margin_x = (target_size[0] - new_width) // 2
+         margin_y = target_size[1] - new_height
+
+     # Adjust margins to eliminate gaps
+     margin_x = max(0, min(margin_x, target_size[0] - new_width))
+     margin_y = max(0, min(margin_y, target_size[1] - new_height))
+
+     # Create a new background image and paste the resized source image
+     background = Image.new('RGB', target_size, (255, 255, 255))
+     background.paste(source, (margin_x, margin_y))
+
+     # Create the mask
+     mask = Image.new('L', target_size, 255)
+     mask_draw = ImageDraw.Draw(mask)
+
+     # Calculate overlap areas
+     white_gaps_patch = 2
+
+     left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
+     right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
+     top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
+     bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
+
+     if alignment == "Left":
+         left_overlap = margin_x + overlap_x if overlap_left else margin_x
+     elif alignment == "Right":
+         right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
+     elif alignment == "Top":
+         top_overlap = margin_y + overlap_y if overlap_top else margin_y
+     elif alignment == "Bottom":
+         bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
+
+     # Draw the mask
+     mask_draw.rectangle([
+         (left_overlap, top_overlap),
+         (right_overlap, bottom_overlap)
+     ], fill=0)
+
+     return background, mask
+
+ def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
+     background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
+
+     # Create a preview image showing the mask
+     preview = background.copy().convert('RGBA')
+
+     # Create a semi-transparent red overlay
+     red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64))  # Reduced alpha to 64 (25% opacity)
+
+     # Convert black pixels in the mask to semi-transparent red
+     red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
+     red_mask.paste(red_overlay, (0, 0), mask)
+
+     # Overlay the red mask on the background
+     preview = Image.alpha_composite(preview, red_mask)
+
+     return preview
+
+ @spaces.GPU(duration=24)
+ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
+     background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
+
+     if not can_expand(background.width, background.height, width, height, alignment):
+         alignment = "Middle"
+
+     cnet_image = background.copy()
+     cnet_image.paste(0, (0, 0), mask)
+
+     final_prompt = f"{prompt_input} , high quality, 4k"
+
+     (
+         prompt_embeds,
+         negative_prompt_embeds,
+         pooled_prompt_embeds,
+         negative_pooled_prompt_embeds,
+     ) = pipe.encode_prompt(final_prompt, "cuda", True)
+
+     for image in pipe(
+         prompt_embeds=prompt_embeds,
+         negative_prompt_embeds=negative_prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+         image=cnet_image,
+         num_inference_steps=num_inference_steps
+     ):
+         yield cnet_image, image
+
+     image = image.convert("RGBA")
+     cnet_image.paste(image, (0, 0), mask)
+
+     yield background, cnet_image
+
+ def clear_result():
+     """Clears the result ImageSlider."""
+     return gr.update(value=None)
+
+ def preload_presets(target_ratio, ui_width, ui_height):
+     """Updates the width and height sliders based on the selected aspect ratio."""
+     if target_ratio == "9:16":
+         changed_width = 720
+         changed_height = 1280
+         return changed_width, changed_height, gr.update()
+     elif target_ratio == "16:9":
+         changed_width = 1280
+         changed_height = 720
+         return changed_width, changed_height, gr.update()
+     elif target_ratio == "1:1":
+         changed_width = 1024
+         changed_height = 1024
+         return changed_width, changed_height, gr.update()
+     elif target_ratio == "Custom":
+         return ui_width, ui_height, gr.update(open=True)
+
+ def select_the_right_preset(user_width, user_height):
+     if user_width == 720 and user_height == 1280:
+         return "9:16"
+     elif user_width == 1280 and user_height == 720:
+         return "16:9"
+     elif user_width == 1024 and user_height == 1024:
+         return "1:1"
+     else:
+         return "Custom"
+
+ def toggle_custom_resize_slider(resize_option):
+     return gr.update(visible=(resize_option == "Custom"))
+
+ def update_history(new_image, history):
+     """Updates the history gallery with the new image."""
+     if history is None:
+         history = []
+     history.insert(0, new_image)
+     return history
+
+ css = """
+ .gradio-container {
+     width: 1200px !important;
+ }
  """

+ title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
+ """

+ with gr.Blocks(css=css) as demo:
+     with gr.Column():
+         gr.HTML(title)
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(
+                     type="pil",
+                     label="Input Image"
+                 )
+
+                 with gr.Row():
+                     with gr.Column(scale=2):
+                         prompt_input = gr.Textbox(label="Prompt (Optional)")
+                     with gr.Column(scale=1):
+                         run_button = gr.Button("Generate")
+
+                 with gr.Row():
+                     target_ratio = gr.Radio(
+                         label="Expected Ratio",
+                         choices=["9:16", "16:9", "1:1", "Custom"],
+                         value="9:16",
+                         scale=2
+                     )
+
+                     alignment_dropdown = gr.Dropdown(
+                         choices=["Middle", "Left", "Right", "Top", "Bottom"],
+                         value="Middle",
+                         label="Alignment"
+                     )
+
+                 with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
+                     with gr.Column():
+                         with gr.Row():
+                             width_slider = gr.Slider(
+                                 label="Target Width",
+                                 minimum=720,
+                                 maximum=1536,
+                                 step=8,
+                                 value=720,  # Set a default value
+                             )
+                             height_slider = gr.Slider(
+                                 label="Target Height",
+                                 minimum=720,
+                                 maximum=1536,
+                                 step=8,
+                                 value=1280,  # Set a default value
+                             )
+
+                         num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
+                         with gr.Group():
+                             overlap_percentage = gr.Slider(
+                                 label="Mask overlap (%)",
+                                 minimum=1,
+                                 maximum=50,
+                                 value=10,
+                                 step=1
+                             )
+                             with gr.Row():
+                                 overlap_top = gr.Checkbox(label="Overlap Top", value=True)
+                                 overlap_right = gr.Checkbox(label="Overlap Right", value=True)
+                             with gr.Row():
+                                 overlap_left = gr.Checkbox(label="Overlap Left", value=True)
+                                 overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
+                         with gr.Row():
+                             resize_option = gr.Radio(
+                                 label="Resize input image",
+                                 choices=["Full", "50%", "33%", "25%", "Custom"],
+                                 value="Full"
+                             )
+                             custom_resize_percentage = gr.Slider(
+                                 label="Custom resize (%)",
+                                 minimum=1,
+                                 maximum=100,
+                                 step=1,
+                                 value=50,
+                                 visible=False
+                             )
+
+                 with gr.Column():
+                     preview_button = gr.Button("Preview alignment and mask")
+
+                 gr.Examples(
+                     examples=[
+                         ["./examples/example_1.webp", 1280, 720, "Middle"],
+                         ["./examples/example_2.jpg", 1440, 810, "Left"],
+                         ["./examples/example_3.jpg", 1024, 1024, "Top"],
+                         ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
+                     ],
+                     inputs=[input_image, width_slider, height_slider, alignment_dropdown],
+                 )
+
+             with gr.Column():
+                 result = ImageSlider(
+                     interactive=False,
+                     label="Generated Image",
+                 )
+                 use_as_input_button = gr.Button("Use as Input Image", visible=False)
+
+                 history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
+                 preview_image = gr.Image(label="Preview")
+
+     def use_output_as_input(output_image):
+         """Sets the generated output as the new input image."""
+         return gr.update(value=output_image[1])
+
+     use_as_input_button.click(
+         fn=use_output_as_input,
+         inputs=[result],
+         outputs=[input_image]
+     )
+
+     target_ratio.change(
+         fn=preload_presets,
+         inputs=[target_ratio, width_slider, height_slider],
+         outputs=[width_slider, height_slider, settings_panel],
+         queue=False
+     )
+
+     width_slider.change(
+         fn=select_the_right_preset,
+         inputs=[width_slider, height_slider],
+         outputs=[target_ratio],
+         queue=False
+     )
+
+     height_slider.change(
+         fn=select_the_right_preset,
+         inputs=[width_slider, height_slider],
+         outputs=[target_ratio],
+         queue=False
+     )
+
+     resize_option.change(
+         fn=toggle_custom_resize_slider,
+         inputs=[resize_option],
+         outputs=[custom_resize_percentage],
+         queue=False
+     )
+
+     run_button.click(  # Clear the result
+         fn=clear_result,
+         inputs=None,
+         outputs=result,
+     ).then(  # Generate the new image
+         fn=infer,
+         inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
+                 resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
+                 overlap_left, overlap_right, overlap_top, overlap_bottom],
+         outputs=result,
+     ).then(  # Update the history gallery
+         fn=lambda x, history: update_history(x[1], history),
+         inputs=[result, history_gallery],
+         outputs=history_gallery,
+     ).then(  # Show the "Use as Input Image" button
+         fn=lambda: gr.update(visible=True),
+         inputs=None,
+         outputs=use_as_input_button,
+     )
+
+     prompt_input.submit(  # Clear the result
+         fn=clear_result,
+         inputs=None,
+         outputs=result,
+     ).then(  # Generate the new image
+         fn=infer,
+         inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
+                 resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
+                 overlap_left, overlap_right, overlap_top, overlap_bottom],
+         outputs=result,
+     ).then(  # Update the history gallery
+         fn=lambda x, history: update_history(x[1], history),
+         inputs=[result, history_gallery],
+         outputs=history_gallery,
+     ).then(  # Show the "Use as Input Image" button
+         fn=lambda: gr.update(visible=True),
+         inputs=None,
+         outputs=use_as_input_button,
+     )
+
+     preview_button.click(
+         fn=preview_image_and_mask,
+         inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
+                 overlap_left, overlap_right, overlap_top, overlap_bottom],
+         outputs=preview_image,
+         queue=False
+     )
+
+ demo.queue(max_size=12).launch(share=False, show_error=True)
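
For reference, the mask geometry that `prepare_image_and_mask` builds in the new app.py can be sketched with Pillow alone. The snippet below is a minimal illustration, not part of the committed code: the `sketch_mask` helper name and the example sizes are made up, it assumes a "Middle" alignment and a source that fits inside the target canvas, and it ignores the resize presets and per-edge overlap toggles that the real function supports.

    # Minimal sketch of the outpaint mask geometry used by prepare_image_and_mask above.
    # Illustrative only: fixed "Middle" alignment, no resize presets or per-edge toggles.
    from PIL import Image, ImageDraw

    def sketch_mask(source_size, target_size, overlap_pct=10):
        src_w, src_h = source_size
        tgt_w, tgt_h = target_size
        # Center the source inside the target canvas ("Middle" alignment).
        margin_x = (tgt_w - src_w) // 2
        margin_y = (tgt_h - src_h) // 2
        # The overlap strip lets the model repaint part of the original image at the seam.
        overlap_x = max(int(src_w * overlap_pct / 100), 1)
        overlap_y = max(int(src_h * overlap_pct / 100), 1)
        # White (255) = area to outpaint, black (0) = area to keep.
        mask = Image.new("L", (tgt_w, tgt_h), 255)
        ImageDraw.Draw(mask).rectangle(
            [
                (margin_x + overlap_x, margin_y + overlap_y),
                (margin_x + src_w - overlap_x, margin_y + src_h - overlap_y),
            ],
            fill=0,
        )
        return mask

    if __name__ == "__main__":
        # Example values only: a 720x720 source expanded to a 1280x720 canvas.
        sketch_mask((720, 720), (1280, 720)).save("mask_preview.png")

In `infer` above, the white region of this mask is blacked out on the canvas via `cnet_image.paste(0, (0, 0), mask)` before the fill pipeline runs, and the same mask is used afterwards to paste the generated pixels back over the background.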