Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,332 +1,210 @@
-"""
-app.py
-
-A unified Gradio chat application for Multimodal OCR Granite Vision.
-Commands (enter these as a prefix in the text input):
-  - @rag: For retrieval-augmented generation (e.g. PDFs or text-based queries).
-  - @granite: For image understanding.
-  - @video-infer: For video understanding (video is downsampled into frames).
-
-The app uses gr.MultimodalTextbox to support text input together with file uploads.
-"""
-
-import os
-import time
-import uuid
-import random
 import spaces
-import logging
-from threading import Thread
-from pathlib import Path
-from datetime import datetime, timezone
-
 import torch
-import numpy as np
-import cv2
 from PIL import Image
-import gradio as gr
-
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    AutoProcessor,
-    AutoModelForVision2Seq,
-)
-
-# ---------------------------
-# Utility functions and setup
-# ---------------------------
-
-def get_device():
-    if torch.backends.mps.is_available():
-        return "mps"  # mac GPU
-    elif torch.cuda.is_available():
-        return "cuda"
-    else:
-        return "cpu"
-
-device = get_device()
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-)
-
-def downsample_video(video_path):
-    """
-    Downsamples the video into 10 evenly spaced frames.
-    Returns a list of (PIL Image, timestamp in seconds) tuples.
-    """
-    vidcap = cv2.VideoCapture(video_path)
-    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
-    frames = []
-    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-    for i in frame_indices:
-        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-        success, image = vidcap.read()
-        if success:
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
-            frames.append((pil_image, timestamp))
-    vidcap.release()
-    return frames
-
-# ---------------------------
-# HF Embedding and LLM classes
-# ---------------------------
-
-class HFEmbedding:
-    def __init__(self, model_id: str):
-        self.model_name = model_id
-        logging.info(f"Loading embeddings model from: {self.model_name}")
-        # Using langchain_huggingface for embeddings
-        from langchain_huggingface import HuggingFaceEmbeddings  # ensure installed
-        # For simplicity, force CPU (adjust if needed)
-        self.embeddings_service = HuggingFaceEmbeddings(
-            model_name=self.model_name,
-            model_kwargs={"device": "cpu"},
-        )
-
-    def embed_documents(self, texts: list[str]) -> list[list[float]]:
-        return self.embeddings_service.embed_documents(texts)
-
-    def embed_query(self, text: str) -> list[float]:
-        return self.embed_documents([text])[0]
-
-class HFLLM:
-    def __init__(self, model_name: str):
-        self.device = device
-        self.model_name = model_name
-        logging.info("Loading HF language model...")
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
-
-    def generate(self, prompt: str) -> list:
-        # Tokenize prompt and generate text
-        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-        generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024)
-        generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
-        # Extract answer assuming a marker in the generated text
-        response = [{"answer": generated_texts[0].split("<|end_of_role|>")[-1].split("<|end_of_text|>")[0]}]
-        return response
-
-# ---------------------------
-# LightRAG: Retrieval-Augmented Generation (Dummy)
-# ---------------------------
-
-class LightRAG:
-    def __init__(self, config: dict):
-        self.config = config
-        # Load generation and embedding models immediately (or lazy load as needed)
-        self.gen_model = HFLLM(config['generation_model_id'])
-        self._embedding_model = HFEmbedding(config['embedding_model_id'])
-
-    def search(self, query: str, top_n: int = 5) -> list:
-        # Dummy retrieval: In practice, integrate with a vector store
-        from langchain_core.documents import Document  # ensure langchain_core is installed
-        dummy_doc = Document(
-            page_content="Dummy context for query: " + query,
-            metadata={"type": "text"}
-        )
-        return [dummy_doc]
-
-    def generate(self, query, context=None):
-        if context is None:
-            context = []
-        # Build prompt by concatenating retrieved context with the query.
-        prompt = ""
-        for doc in context:
-            prompt += doc.page_content + "\n"
-        prompt += "\nQuestion: " + query + "\nAnswer:"
-        results = self.gen_model.generate(prompt)
-        answer = results[0]["answer"]
-        return answer, prompt
 
-
-
-
-
 }
[lines 152-227 of the removed app.py are collapsed in this diff view and not shown]
-    )
-    decoded = processor.decode(output[0], skip_special_tokens=True)
-    parts = decoded.strip().split("<|assistant|>")
-    return parts[-1].strip()
-
-# ---------------------------
-# Unified generation function for ChatInterface
-# ---------------------------
-@spaces.GPU
-def generate(input_dict: dict, chat_history: list[dict],
-             max_new_tokens: int, temperature: float,
-             top_p: float, top_k: int, repetition_penalty: float):
-    """
-    Chat function that inspects the input text for special commands and routes:
-      - @rag: Uses the RAG pipeline.
-      - @granite: Uses Granite Vision for image understanding.
-      - @video-infer: Uses Granite Vision for video processing.
-    """
-    text = input_dict["text"]
-    files = input_dict.get("files", [])
-    lower_text = text.strip().lower()
-
-    # Optionally yield a progress message
-    yield "Processing your request..."
-    time.sleep(1)  # simulate processing delay
-
-    if lower_text.startswith("@rag"):
-        query = text[len("@rag"):].strip()
-        logging.info(f"@rag command: {query}")
-        context = light_rag.search(query)
-        answer, _ = light_rag.generate(query, context)
-        yield answer
-
-    elif lower_text.startswith("@granite"):
-        prompt_text = text[len("@granite"):].strip()
-        logging.info(f"@granite command: {prompt_text}")
-        if files:
-            # Expecting an image file (as a PIL image)
-            image = files[0]
-            answer = generate_granite(image, prompt_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
-            yield answer
-        else:
-            yield "No image provided for @granite command."
-
-    elif lower_text.startswith("@video-infer"):
-        prompt_text = text[len("@video-infer"):].strip()
-        logging.info(f"@video-infer command: {prompt_text}")
-        if files:
-            # Expecting a video file (the file path)
-            video_path = files[0]
-            answer = generate_video_infer(video_path, prompt_text, max_new_tokens, temperature, top_p, top_k, repetition_penalty)
-            yield answer
-        else:
-            yield "No video provided for @video-infer command."
-
-    else:
-        # Default behavior: use RAG pipeline for text query.
-        query = text.strip()
-        logging.info(f"Default text query: {query}")
-        context = light_rag.search(query)
-        answer, _ = light_rag.generate(query, context)
-        yield answer
-
-# ---------------------------
-# Gradio ChatInterface using MultimodalTextbox
-# ---------------------------
-
-demo = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
-        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
-        gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.1, value=0.85),
-        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
-        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.05),
-    ],
-    examples=[
-        # Examples show how to use the command prefixes.
-        [{"text": "@rag What models are available in Watsonx?"}],
-        [{"text": "@granite Describe the image", "files": [str(Path("examples") / "sample_image.png")]}],
-        [{"text": "@video-infer Summarize the event in the video", "files": [str(Path("examples") / "sample_video.mp4")]}],
-    ],
-    cache_examples=False,
-    type="messages",
-    description=(
-        "# **Multimodal OCR Granite Vision**\n\n"
-        "Enter a command in the text input (with optional file uploads) using one of the following prefixes:\n\n"
-        "- **@rag**: For retrieval-augmented generation (e.g. PDFs, documents).\n"
-        "- **@granite**: For image understanding using Granite Vision.\n"
-        "- **@video-infer**: For video understanding (video is downsampled into frames).\n\n"
-        "For example:\n```\n@rag What is the revenue trend?\n```\n```\n@granite Describe this image\n```\n```\n@video-infer Summarize the event in this video\n```"
-    ),
-    fill_height=True,
-    textbox=gr.MultimodalTextbox(
-        label="Query Input",
-        file_types=["image", "video", "pdf"],
-        file_count="multiple",
-        placeholder="@rag, @granite, or @video-infer followed by your prompt"
-    ),
-    stop_btn="Stop Generation",
-    multimodal=True,
-)
 
 if __name__ == "__main__":
-    demo.queue(max_size=
 import spaces
+import gradio as gr
 import torch
 from PIL import Image
+from diffusers import DiffusionPipeline
+import random
+import uuid
+from typing import Tuple
+import numpy as np
 
+def save_image(img):
+    unique_name = str(uuid.uuid4()) + ".png"
+    img.save(unique_name)
+    return unique_name
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+MAX_SEED = np.iinfo(np.int32).max
+
+if not torch.cuda.is_available():
+    DESCRIPTIONz += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
+
+base_model = "black-forest-labs/FLUX.1-dev"
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
+
+lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
+trigger_word = "Super Realism"  # Leave trigger_word blank if not used.
+
+pipe.load_lora_weights(lora_repo)
+pipe.to("cuda")
+
+style_list = [
+    {
+        "name": "3840 x 2160",
+        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+    },
+    {
+        "name": "2560 x 1440",
+        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+    },
+    {
+        "name": "HD+",
+        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+    },
+    {
+        "name": "Style Zero",
+        "prompt": "{prompt}",
+    },
+]
+
+styles = {k["name"]: k["prompt"] for k in style_list}
+
+DEFAULT_STYLE_NAME = "3840 x 2160"
+STYLE_NAMES = list(styles.keys())
+
+def apply_style(style_name: str, positive: str) -> str:
+    return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
+
+@spaces.GPU(duration=60, enable_queue=True)
+def generate(
+    prompt: str,
+    seed: int = 0,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3,
+    randomize_seed: bool = False,
+    style_name: str = DEFAULT_STYLE_NAME,
+    progress=gr.Progress(track_tqdm=True),
+):
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+
+    positive_prompt = apply_style(style_name, prompt)
+
+    if trigger_word:
+        positive_prompt = f"{trigger_word} {positive_prompt}"
+
+    images = pipe(
+        prompt=positive_prompt,
+        width=width,
+        height=height,
+        guidance_scale=guidance_scale,
+        num_inference_steps=28,
+        num_images_per_prompt=1,
+        output_type="pil",
+    ).images
+    image_paths = [save_image(img) for img in images]
+    print(image_paths)
+    return image_paths, seed
+
+examples = [
+    "Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250",
+    "Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6",
+    "Super Realism, Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, serious look on his face, black background, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style",
+    "Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The womans eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."
+]
+
+css = '''
+.gradio-container{max-width: 888px !important}
+h1{text-align:center}
+footer {
+    visibility: hidden
 }
+.submit-btn {
+    background-color: #e34949 !important;
+    color: white !important;
+}
+.submit-btn:hover {
+    background-color: #ff3b3b !important;
+}
+'''
+
+with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+    with gr.Row():
+        with gr.Column(scale=1):
+            prompt = gr.Text(
+                label="Prompt",
+                show_label=False,
+                max_lines=1,
+                placeholder="Enter your prompt",
+                container=False,
+            )
+            run_button = gr.Button("Generate as ( 768 x 1024 )🤗", scale=0, elem_classes="submit-btn")
+
+            with gr.Accordion("Advanced options", open=True, visible=True):
+                seed = gr.Slider(
+                    label="Seed",
+                    minimum=0,
+                    maximum=MAX_SEED,
+                    step=1,
+                    value=0,
+                    visible=True
+                )
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                with gr.Row(visible=True):
+                    width = gr.Slider(
+                        label="Width",
+                        minimum=512,
+                        maximum=2048,
+                        step=64,
+                        value=768,
+                    )
+                    height = gr.Slider(
+                        label="Height",
+                        minimum=512,
+                        maximum=2048,
+                        step=64,
+                        value=1024,
+                    )
+
+                with gr.Row():
+                    guidance_scale = gr.Slider(
+                        label="Guidance Scale",
+                        minimum=0.1,
+                        maximum=20.0,
+                        step=0.1,
+                        value=3.0,
+                    )
+                    num_inference_steps = gr.Slider(
+                        label="Number of inference steps",
+                        minimum=1,
+                        maximum=40,
+                        step=1,
+                        value=28,
+                    )
+
+                style_selection = gr.Radio(
+                    show_label=True,
+                    container=True,
+                    interactive=True,
+                    choices=STYLE_NAMES,
+                    value=DEFAULT_STYLE_NAME,
+                    label="Quality Style",
+                )
+
+        with gr.Column(scale=2):
+            result = gr.Gallery(label="Result", columns=1, show_label=False)
+
+            gr.Examples(
+                examples=examples,
+                inputs=prompt,
+                outputs=[result, seed],
+                fn=generate,
+                cache_examples=False,
+            )
+
+    gr.on(
+        triggers=[
+            prompt.submit,
+            run_button.click,
+        ],
+        fn=generate,
+        inputs=[
+            prompt,
+            seed,
+            width,
+            height,
+            guidance_scale,
+            randomize_seed,
+            style_selection,
+        ],
+        outputs=[result, seed],
+        api_name="run",
     )
 
 if __name__ == "__main__":
+    demo.queue(max_size=40).launch()
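After this change the Space serves a single text-to-image endpoint: the gr.on() wiring above registers generate under api_name="run", taking the prompt, seed, width, height, guidance scale, randomize-seed flag, and quality style, and returning the gallery of saved image paths plus the seed that was used. A minimal client-side sketch of calling it with gradio_client (assuming the Space is deployed and gradio_client is installed; "author/space-name" is a placeholder id, not the actual repository):

from gradio_client import Client

# Placeholder Space id; substitute the actual "user/space" repository name.
client = Client("author/space-name")

# Positional inputs mirror the gr.on() inputs list:
# prompt, seed, width, height, guidance_scale, randomize_seed, style_selection.
gallery, used_seed = client.predict(
    "Super Realism, portrait photograph of a man, studio lighting",
    0,               # seed (ignored when randomize_seed is True)
    768,             # width
    1024,            # height
    3.0,             # guidance_scale
    True,            # randomize_seed
    "3840 x 2160",   # quality style
    api_name="/run",
)
print(gallery, used_seed)

Because the UI defaults the "Randomize seed" checkbox to True, passing True here means the returned seed is the only way to reproduce a particular result.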