gokaygokay
committed on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
from huggingface_hub import snapshot_download
|
6 |
+
from diffusers import StableDiffusionXLPipeline, AutoencoderKL
|
7 |
+
from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler
|
8 |
+
import gradio as gr
|
9 |
+
from PIL import Image
|
10 |
+
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
|
11 |
+
|
12 |
+
import subprocess
|
13 |
+
# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD is passed to
# the install command via its environment. The variable is merged into a copy
# of the current environment: passing a bare dict as `env` would *replace* the
# child's environment and drop PATH, HOME and any virtualenv settings, so `pip`
# could resolve to a different interpreter or not be found at all.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
|
14 |
+
|
15 |
+
# Fetch the checkpoint snapshot; `ckpt_dir` is the path used for all loads below.
ckpt_dir = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")

# Load the VAE from the snapshot's "vae" sub-directory in half precision.
vae_dir = os.path.join(ckpt_dir, "vae")
vae = AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch.float16)

# Build the SDXL pipeline around that VAE (fp16 safetensors weights) and move
# it to the CUDA device.
pipe = StableDiffusionXLPipeline.from_pretrained(
    ckpt_dir,
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")
|
29 |
+
|
30 |
+
# Map each sampler name offered in the UI to a scheduler instance. Every
# scheduler is derived from the configuration of the pipeline's own scheduler.
_scheduler_config = pipe.scheduler.config
samplers = {
    "Euler a": EulerAncestralDiscreteScheduler.from_config(_scheduler_config),
    "DPM++ 2M": DPMSolverMultistepScheduler.from_config(
        _scheduler_config,
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True,
    ),
    "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(
        _scheduler_config,
        use_karras_sigmas=True,
    ),
}
|
36 |
+
|
37 |
+
# Fixed fragments wrapped around the user's text when prompts are assembled:
# the prefix goes before the user's prompt and the suffix after it.
DEFAULT_POSITIVE_PREFIX = "score_9, score_8_up, score_7_up, BREAK,"
DEFAULT_POSITIVE_SUFFIX = "(masterpiece), best quality, very aesthetic, perfect face"

DEFAULT_NEGATIVE_PREFIX = "score_1, score_2, score_3, text"
DEFAULT_NEGATIVE_SUFFIX = (
    "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn"
)
|
41 |
+
|
42 |
+
# Device for the auxiliary models: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Florence-2 captioning model (placed on `device`, switched to eval mode) and
# its matching processor.
florence_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Florence-2-base',
    trust_remote_code=True,
)
florence_model = florence_model.to(device).eval()
florence_processor = AutoProcessor.from_pretrained(
    'microsoft/Florence-2-base',
    trust_remote_code=True,
)

# Prompt enhancers: two summarization pipelines, one per enhancer choice
# ("Medium" and "Long").
enhancer_medium = pipeline(
    "summarization",
    model="gokaygokay/Lamini-Prompt-Enchance",
    device=device,
)
enhancer_long = pipeline(
    "summarization",
    model="gokaygokay/Lamini-Prompt-Enchance-Long",
    device=device,
)
|
50 |
+
|
51 |
+
# Florence caption function
|
52 |
+
def florence_caption(image):
    """Return the Florence-2 "<DETAILED_CAPTION>" text for ``image``.

    Args:
        image: A ``PIL.Image.Image``, or any object accepted by
            ``Image.fromarray`` (it is converted when it is not already a
            PIL image).

    Returns:
        The value stored under the ``"<DETAILED_CAPTION>"`` key of the
        processor's post-processed generation.
    """
    task = "<DETAILED_CAPTION>"

    # Normalise the input to a PIL image; width/height are read from it below.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    model_inputs = florence_processor(
        text=task,
        images=pil_image,
        return_tensors="pt",
    ).to(device)

    # Deterministic beam search (no sampling).
    output_ids = florence_model.generate(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Special tokens are kept in the decoded text that is handed to
    # post_process_generation.
    decoded = florence_processor.batch_decode(output_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        decoded,
        task=task,
        image_size=(pil_image.width, pil_image.height),
    )
    return parsed[task]
|
73 |
+
|
74 |
+
# Prompt Enhancer function
|
75 |
+
def enhance_prompt(input_prompt, model_choice):
    """Expand ``input_prompt`` with one of the prompt-enhancer pipelines.

    Args:
        input_prompt: Text to enhance.
        model_choice: ``"Medium"`` selects ``enhancer_medium``; any other
            value selects ``enhancer_long``.

    Returns:
        The ``'summary_text'`` of the first result returned by the pipeline.
    """
    enhancer = enhancer_medium if model_choice == "Medium" else enhancer_long
    outputs = enhancer("Enhance the description: " + input_prompt)
    return outputs[0]['summary_text']
|
84 |
+
|
85 |
+
def _join_prompt_parts(*parts):
    """Join prompt fragments with ", ", skipping empty ones.

    Each fragment is first trimmed of surrounding whitespace and commas so the
    result never contains "BREAK,text" (missing space) or ",," (empty slot).
    """
    trimmed = (part.strip(", \t\r\n") for part in parts if part)
    return ", ".join(piece for piece in trimmed if piece)


@spaces.GPU(duration=120)
def generate_image(additional_positive_prompt, additional_negative_prompt, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler, clip_skip, use_florence2, use_medium_enhancer, use_long_enhancer, input_image=None, progress=gr.Progress(track_tqdm=True)):
    """Generate images with the SDXL pipeline from the UI settings.

    Args:
        additional_positive_prompt: User text inserted between the default
            positive prefix and suffix. May be empty.
        additional_negative_prompt: User text inserted between the default
            negative prefix and suffix. May be empty.
        height, width: Output size forwarded to the pipeline.
        num_inference_steps, guidance_scale, num_images_per_prompt: Forwarded
            to the pipeline unchanged.
        use_random_seed: When true, a fresh seed in ``[0, 2**32 - 1]`` is drawn.
        seed: Seed used when ``use_random_seed`` is false. A missing value
            (``None``) falls back to a random seed instead of raising.
        sampler: Key into ``samplers`` selecting the scheduler.
        clip_skip: UI "Clip skip" value; ``clip_skip - 1`` is subtracted from
            the text encoder's configured ``num_hidden_layers`` for this call.
        use_florence2: When true and ``input_image`` is given, a Florence-2
            caption is placed in front of the user's positive prompt.
        use_medium_enhancer, use_long_enhancer: Append the output of the
            corresponding prompt enhancer to the positive prompt.
        input_image: Optional image for the captioner.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        ``(images, seed, full_positive_prompt)``. ``images`` is the pipeline's
        ``.images`` result, or ``None`` if generation raised.
    """
    if use_random_seed or seed is None:
        seed = random.randint(0, 2**32 - 1)
    else:
        seed = int(seed)  # Ensure seed is an integer

    # Set the scheduler based on the selected sampler
    pipe.scheduler = samplers[sampler]

    # Prepend the Florence-2 caption if enabled and an image is provided.
    user_prompt = additional_positive_prompt
    if use_florence2 and input_image is not None:
        florence2_caption = florence_caption(input_image).lower().replace('.', ',')
        user_prompt = _join_prompt_parts(florence2_caption, user_prompt)

    # Enhance only the user/caption part of the prompt. The long enhancer
    # receives the prompt that already includes the medium enhancer's output.
    enhanced_prompt = _join_prompt_parts(user_prompt)
    if enhanced_prompt:
        if use_medium_enhancer:
            medium_enhanced = enhance_prompt(enhanced_prompt, "Medium").lower().replace('.', ',')
            enhanced_prompt = _join_prompt_parts(enhanced_prompt, medium_enhanced)
        if use_long_enhancer:
            long_enhanced = enhance_prompt(enhanced_prompt, "Long").lower().replace('.', ',')
            enhanced_prompt = _join_prompt_parts(enhanced_prompt, long_enhanced)

    full_positive_prompt = _join_prompt_parts(
        DEFAULT_POSITIVE_PREFIX, enhanced_prompt, DEFAULT_POSITIVE_SUFFIX
    )
    full_negative_prompt = _join_prompt_parts(
        DEFAULT_NEGATIVE_PREFIX, additional_negative_prompt, DEFAULT_NEGATIVE_SUFFIX
    )

    # Apply clip skip for this call only. The previous in-place `-=` on the
    # shared config accumulated across calls, shrinking the layer count each
    # time; the original value is now restored in `finally`.
    # NOTE(review): whether the pipeline reads `config.num_hidden_layers` at
    # inference time is not visible here; the SDXL pipeline call also accepts
    # a `clip_skip` argument - confirm which mechanism is intended.
    text_encoder_config = pipe.text_encoder.config
    base_num_layers = text_encoder_config.num_hidden_layers
    text_encoder_config.num_hidden_layers = base_num_layers - (int(clip_skip) - 1)

    try:
        image = pipe(
            prompt=full_positive_prompt,
            negative_prompt=full_negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            generator=torch.Generator(pipe.device).manual_seed(seed)
        ).images
        return image, seed, full_positive_prompt
    except Exception as e:
        # Best effort: report the failure and still return the seed and prompt.
        print(f"Error during image generation: {str(e)}")
        return None, seed, full_positive_prompt
    finally:
        text_encoder_config.num_hidden_layers = base_num_layers
|
141 |
+
|
142 |
+
# Gradio interface
|
143 |
+
with gr.Blocks(theme='bethecloud/storj_theme') as demo:
    # Page header with links to the models used by this app.
    gr.HTML("""
    <h1 align="center">Pony Realism v21 SDXL - Text-to-Image Generation</h1>
    <p align="center">
    <a href="https://huggingface.co/John6666/pony-realism-v21main-sdxl/" target="_blank">[HF Model Page]</a>
    <a href="https://civitai.com/models/372465/pony-realism" target="_blank">[civitai Model Page]</a>
    <a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
    <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
    <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance" target="_blank">[Prompt Enhancer Medium]</a>
    </p>
    """)

    with gr.Row():
        # Left column: prompts, settings and the generate button.
        with gr.Column(scale=1):
            positive_prompt = gr.Textbox(
                label="Positive Prompt",
                placeholder="Add your positive prompt here",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="Add your negative prompt here",
            )

            with gr.Accordion("Advanced settings", open=False):
                height = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Height")
                width = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Width")
                num_inference_steps = gr.Slider(
                    minimum=20, maximum=50, value=30, step=1,
                    label="Number of Inference Steps",
                )
                guidance_scale = gr.Slider(
                    minimum=1, maximum=20, value=6, step=0.1,
                    label="Guidance Scale",
                )
                num_images_per_prompt = gr.Slider(
                    minimum=1, maximum=4, value=1, step=1,
                    label="Number of images per prompt",
                )
                use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
                seed = gr.Number(label="Seed", value=0, precision=0)
                sampler = gr.Dropdown(
                    label="Sampler",
                    choices=list(samplers.keys()),
                    value="DPM++ SDE Karras",
                )
                clip_skip = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Clip skip")

            with gr.Accordion("Captioner and Enhancers", open=False):
                input_image = gr.Image(label="Input Image for Florence-2 Captioner")
                use_florence2 = gr.Checkbox(label="Use Florence-2 Captioner", value=False)
                use_medium_enhancer = gr.Checkbox(label="Use Medium Prompt Enhancer", value=False)
                use_long_enhancer = gr.Checkbox(label="Use Long Prompt Enhancer", value=False)

            generate_btn = gr.Button("Generate Image")

        # Right column: results.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Result", elem_id="gallery", show_label=False)
            seed_used = gr.Number(label="Seed Used")
            full_prompt_used = gr.Textbox(label="Full Positive Prompt Used")

    # Component order must match generate_image's positional parameters.
    generation_inputs = [
        positive_prompt,
        negative_prompt,
        height,
        width,
        num_inference_steps,
        guidance_scale,
        num_images_per_prompt,
        use_random_seed,
        seed,
        sampler,
        clip_skip,
        use_florence2,
        use_medium_enhancer,
        use_long_enhancer,
        input_image,
    ]
    generation_outputs = [output_gallery, seed_used, full_prompt_used]

    generate_btn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )

demo.launch(debug=True)
|