Deadmon commited on
Commit
85330fa
1 Parent(s): 444af28

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import spaces
3
+ from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
4
+ from transformers import AutoFeatureExtractor
5
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
6
+ from ip_adapter.ip_adapter_faceid import IPAdapterFaceID, IPAdapterFaceIDPlus
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
+ from insightface.app import FaceAnalysis
9
+ from insightface.utils import face_align
10
+ import gradio as gr
11
+ import cv2
12
+ import os
13
+
14
+ # Model paths
15
+ model_paths = {
16
+ "Realistic Vision V4.0": "SG161222/Realistic_Vision_V4.0_noVAE",
17
+ "Pony Realism v21": snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl"),
18
+ "Cyber Realistic Pony v61": snapshot_download(repo_id="John6666/cyberrealistic-pony-v61-sdxl"),
19
+ "Stallion Dreams Pony Realistic v1": snapshot_download(repo_id="John6666/stallion-dreams-pony-realistic-v1-sdxl")
20
+ }
21
+ vae_model_path = "stabilityai/sd-vae-ft-mse"
22
+ image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
23
+ ip_ckpt = hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sd15.bin", repo_type="model")
24
+ ip_plus_ckpt = hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid-plusv2_sd15.bin", repo_type="model")
25
+
26
+ # Safety Checker Setup
27
+ safety_model_id = "CompVis/stable-diffusion-safety-checker"
28
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
29
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
30
+
31
+ device = "cuda"
32
+
33
+ # Define the scheduler
34
+ noise_scheduler = DDIMScheduler(
35
+ num_train_timesteps=1000,
36
+ beta_start=0.00085,
37
+ beta_end=0.012,
38
+ beta_schedule="scaled_linear",
39
+ clip_sample=False,
40
+ set_alpha_to_one=False,
41
+ steps_offset=1,
42
+ )
43
+ vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
44
+
45
+ # Face analysis setup
46
+ app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
47
+ app.prepare(ctx_id=0, det_size=(640, 640))
48
+
49
+ cv2.setNumThreads(1)
50
+
51
+ # Function to load the appropriate pipeline based on user selection
52
+ def load_model(model_choice):
53
+ model_path = model_paths[model_choice]
54
+ pipeline = StableDiffusionPipeline.from_pretrained(
55
+ model_path,
56
+ torch_dtype=torch.float16,
57
+ scheduler=noise_scheduler,
58
+ vae=vae,
59
+ feature_extractor=safety_feature_extractor,
60
+ safety_checker=None
61
+ ).to(device)
62
+
63
+ # Load the IP Adapter models
64
+ ip_model = IPAdapterFaceID(pipeline, ip_ckpt, device)
65
+ ip_model_plus = IPAdapterFaceIDPlus(pipeline, image_encoder_path, ip_plus_ckpt, device)
66
+
67
+ return pipeline, ip_model, ip_model_plus
68
+
69
+ # Gradio function to generate images
70
+ @spaces.GPU(enable_queue=True)
71
+ def generate_image(images, prompt, negative_prompt, preserve_face_structure, face_strength, likeness_strength, nfaa_negative_prompt, model_choice, progress=gr.Progress(track_tqdm=True)):
72
+ pipeline, ip_model, ip_model_plus = load_model(model_choice)
73
+ faceid_all_embeds = []
74
+ first_iteration = True
75
+ for image in images:
76
+ face = cv2.imread(image)
77
+ faces = app.get(face)
78
+ faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
79
+ faceid_all_embeds.append(faceid_embed)
80
+ if first_iteration and preserve_face_structure:
81
+ face_image = face_align.norm_crop(face, landmark=faces[0].kps, image_size=224)
82
+ first_iteration = False
83
+
84
+ average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
85
+ total_negative_prompt = f"{negative_prompt} {nfaa_negative_prompt}"
86
+
87
+ if not preserve_face_structure:
88
+ image = ip_model.generate(
89
+ prompt=prompt,
90
+ negative_prompt=total_negative_prompt,
91
+ faceid_embeds=average_embedding,
92
+ scale=likeness_strength,
93
+ width=512,
94
+ height=512,
95
+ num_inference_steps=30
96
+ )
97
+ else:
98
+ image = ip_model_plus.generate(
99
+ prompt=prompt,
100
+ negative_prompt=total_negative_prompt,
101
+ faceid_embeds=average_embedding,
102
+ scale=likeness_strength,
103
+ face_image=face_image,
104
+ shortcut=True,
105
+ s_scale=face_strength,
106
+ width=512,
107
+ height=512,
108
+ num_inference_steps=30
109
+ )
110
+ return image
111
+
112
+ def change_style(style):
113
+ if style == "Photorealistic":
114
+ return gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0)
115
+ else:
116
+ return gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8)
117
+
118
+ def swap_to_gallery(images):
119
+ return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
120
+
121
+ def remove_back_to_files():
122
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
123
+
124
+ css = '''
125
+ h1{margin-bottom: 0 !important}
126
+ footer{display:none !important}
127
+ '''
128
+
129
+ with gr.Blocks(css=css) as demo:
130
+ gr.Markdown("")
131
+ gr.Markdown("")
132
+ with gr.Row():
133
+ with gr.Column():
134
+ files = gr.Files(
135
+ label="Drag 1 or more photos of your face",
136
+ file_types=["image"]
137
+ )
138
+ uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)
139
+ with gr.Column(visible=False) as clear_button:
140
+ remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
141
+ prompt = gr.Textbox(
142
+ label="Prompt",
143
+ info="Try something like 'a photo of a man/woman/person'",
144
+ placeholder="A photo of a [man/woman/person]..."
145
+ )
146
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality")
147
+ style = gr.Radio(
148
+ label="Generation type",
149
+ info="For stylized try prompts like 'a watercolor painting of a woman'",
150
+ choices=["Photorealistic", "Stylized"],
151
+ value="Photorealistic"
152
+ )
153
+ model_choice = gr.Dropdown(
154
+ label="Model Choice",
155
+ choices=list(model_paths.keys()),
156
+ value="Realistic Vision V4.0"
157
+ )
158
+ submit = gr.Button("Submit")
159
+ with gr.Accordion(open=False, label="Advanced Options"):
160
+ preserve = gr.Checkbox(
161
+ label="Preserve Face Structure",
162
+ info="Higher quality, less versatility (the face structure of your first photo will be preserved). Unchecking this will use the v1 model.",
163
+ value=True
164
+ )
165
+ face_strength = gr.Slider(
166
+ label="Face Structure strength",
167
+ info="Only applied if preserve face structure is checked",
168
+ value=1.3,
169
+ step=0.1,
170
+ minimum=0,
171
+ maximum=3
172
+ )
173
+ likeness_strength = gr.Slider(label="Face Embed strength", value=1.0, step=0.1, minimum=0, maximum=5)
174
+ nfaa_negative_prompts = gr.Textbox(
175
+ label="Appended Negative Prompts",
176
+ info="Negative prompts to steer generations towards safe for all audiences outputs",
177
+ value="naked, bikini, skimpy, scanty, bare skin, lingerie, swimsuit, exposed, see-through"
178
+ )
179
+ with gr.Column():
180
+ gallery = gr.Gallery(label="Generated Images")
181
+ style.change(fn=change_style,
182
+ inputs=style,
183
+ outputs=[preserve, face_strength, likeness_strength])
184
+ files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
185
+ remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
186
+ submit.click(
187
+ fn=generate_image,
188
+ inputs=[files, prompt, negative_prompt, preserve, face_strength, likeness_strength, nfaa_negative_prompts, model_choice],
189
+ outputs=gallery
190
+ )
191
+
192
+ gr.Markdown("")
193
+
194
+ demo.launch()