Deadmon committed on
Commit 02269e1 · verified · 1 Parent(s): 85330fa

Update app.py

Files changed (1)
  1. app.py +69 -99
app.py CHANGED
@@ -1,36 +1,28 @@
import torch
import spaces
- from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
+ from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL, StableDiffusionXLPipeline
from transformers import AutoFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ip_adapter.ip_adapter_faceid import IPAdapterFaceID, IPAdapterFaceIDPlus
- from huggingface_hub import hf_hub_download, snapshot_download
+ from huggingface_hub import hf_hub_download
from insightface.app import FaceAnalysis
from insightface.utils import face_align
import gradio as gr
import cv2
import os

- # Model paths
- model_paths = {
-     "Realistic Vision V4.0": "SG161222/Realistic_Vision_V4.0_noVAE",
-     "Pony Realism v21": snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl"),
-     "Cyber Realistic Pony v61": snapshot_download(repo_id="John6666/cyberrealistic-pony-v61-sdxl"),
-     "Stallion Dreams Pony Realistic v1": snapshot_download(repo_id="John6666/stallion-dreams-pony-realistic-v1-sdxl")
- }
+ base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
vae_model_path = "stabilityai/sd-vae-ft-mse"
image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
ip_ckpt = hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sd15.bin", repo_type="model")
ip_plus_ckpt = hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid-plusv2_sd15.bin", repo_type="model")

- # Safety Checker Setup
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

device = "cuda"

- # Define the scheduler
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
@@ -41,35 +33,39 @@ noise_scheduler = DDIMScheduler(
    steps_offset=1,
)
vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
+ pipe = StableDiffusionPipeline.from_pretrained(
+     base_model_path,
+     torch_dtype=torch.float16,
+     scheduler=noise_scheduler,
+     vae=vae,
+     feature_extractor=safety_feature_extractor,
+     safety_checker=None # <--- Disable safety checker
+ ).to(device)
+
+ ip_model = IPAdapterFaceID(pipe, ip_ckpt, device)
+ ip_model_plus = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_plus_ckpt, device)

- # Face analysis setup
app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

cv2.setNumThreads(1)

- # Function to load the appropriate pipeline based on user selection
- def load_model(model_choice):
-     model_path = model_paths[model_choice]
-     pipeline = StableDiffusionPipeline.from_pretrained(
-         model_path,
-         torch_dtype=torch.float16,
-         scheduler=noise_scheduler,
-         vae=vae,
-         feature_extractor=safety_feature_extractor,
-         safety_checker=None
-     ).to(device)
-
-     # Load the IP Adapter models
-     ip_model = IPAdapterFaceID(pipeline, ip_ckpt, device)
-     ip_model_plus = IPAdapterFaceIDPlus(pipeline, image_encoder_path, ip_plus_ckpt, device)
-
-     return pipeline, ip_model, ip_model_plus
-
- # Gradio function to generate images
+ # Download the SDXL model files
+ ckpt_dir_pony = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
+ ckpt_dir_cyber = snapshot_download(repo_id="John6666/cyberrealistic-pony-v61-sdxl")
+ ckpt_dir_stallion = snapshot_download(repo_id="John6666/stallion-dreams-pony-realistic-v1-sdxl")
+
+ # Load the SDXL models
+ pipe_pony = StableDiffusionXLPipeline.from_pretrained(ckpt_dir_pony, torch_dtype=torch.float16)
+ pipe_cyber = StableDiffusionXLPipeline.from_pretrained(ckpt_dir_cyber, torch_dtype=torch.float16)
+ pipe_stallion = StableDiffusionXLPipeline.from_pretrained(ckpt_dir_stallion, torch_dtype=torch.float16)
+
+ pipe_pony.to(device)
+ pipe_cyber.to(device)
+ pipe_stallion.to(device)
+
@spaces.GPU(enable_queue=True)
def generate_image(images, prompt, negative_prompt, preserve_face_structure, face_strength, likeness_strength, nfaa_negative_prompt, model_choice, progress=gr.Progress(track_tqdm=True)):
-     pipeline, ip_model, ip_model_plus = load_model(model_choice)
    faceid_all_embeds = []
    first_iteration = True
    for image in images:
@@ -77,43 +73,41 @@ def generate_image(images, prompt, negative_prompt, preserve_face_structure, fac
        faces = app.get(face)
        faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
        faceid_all_embeds.append(faceid_embed)
-         if first_iteration and preserve_face_structure:
-             face_image = face_align.norm_crop(face, landmark=faces[0].kps, image_size=224)
+         if(first_iteration and preserve_face_structure):
+             face_image = face_align.norm_crop(face, landmark=faces[0].kps, image_size=224) # you can also segment the face
        first_iteration = False
-
+
    average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
+
    total_negative_prompt = f"{negative_prompt} {nfaa_negative_prompt}"
-
-     if not preserve_face_structure:
+
+     if model_choice == "Pony Realism v21":
+         pipe = pipe_pony
+     elif model_choice == "Cyber Realistic Pony v61":
+         pipe = pipe_cyber
+     else: # "Stallion Dreams Pony Realistic v1"
+         pipe = pipe_stallion
+
+     if(not preserve_face_structure):
+         print("Generating normal")
        image = ip_model.generate(
-             prompt=prompt,
-             negative_prompt=total_negative_prompt,
-             faceid_embeds=average_embedding,
-             scale=likeness_strength,
-             width=512,
-             height=512,
-             num_inference_steps=30
+             prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
+             scale=likeness_strength, width=512, height=512, num_inference_steps=30, pipe=pipe
        )
    else:
+         print("Generating plus")
        image = ip_model_plus.generate(
-             prompt=prompt,
-             negative_prompt=total_negative_prompt,
-             faceid_embeds=average_embedding,
-             scale=likeness_strength,
-             face_image=face_image,
-             shortcut=True,
-             s_scale=face_strength,
-             width=512,
-             height=512,
-             num_inference_steps=30
+             prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
+             scale=likeness_strength, face_image=face_image, shortcut=True, s_scale=face_strength, width=512, height=512, num_inference_steps=30, pipe=pipe
        )
+     print(image)
    return image

def change_style(style):
    if style == "Photorealistic":
-         return gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0)
+         return(gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0))
    else:
-         return gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8)
+         return(gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8))

def swap_to_gallery(images):
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
@@ -132,63 +126,39 @@ with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            files = gr.Files(
-                 label="Drag 1 or more photos of your face",
-                 file_types=["image"]
-             )
+                 label="Drag 1 or more photos of your face",
+                 file_types=["image"]
+             )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
-             prompt = gr.Textbox(
-                 label="Prompt",
-                 info="Try something like 'a photo of a man/woman/person'",
-                 placeholder="A photo of a [man/woman/person]..."
-             )
+             prompt = gr.Textbox(label="Prompt",
+                        info="Try something like 'a photo of a man/woman/person'",
+                        placeholder="A photo of a [man/woman/person]...")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality")
-             style = gr.Radio(
-                 label="Generation type",
-                 info="For stylized try prompts like 'a watercolor painting of a woman'",
-                 choices=["Photorealistic", "Stylized"],
-                 value="Photorealistic"
-             )
+             style = gr.Radio(label="Generation type", info="For stylized try prompts like 'a watercolor painting of a woman'", choices=["Photorealistic", "Stylized"], value="Photorealistic")
            model_choice = gr.Dropdown(
+                 ["Pony Realism v21", "Cyber Realistic Pony v61", "Stallion Dreams Pony Realistic v1"],
                label="Model Choice",
-                 choices=list(model_paths.keys()),
-                 value="Realistic Vision V4.0"
+                 value="Pony Realism v21"
            )
            submit = gr.Button("Submit")
            with gr.Accordion(open=False, label="Advanced Options"):
-                 preserve = gr.Checkbox(
-                     label="Preserve Face Structure",
-                     info="Higher quality, less versatility (the face structure of your first photo will be preserved). Unchecking this will use the v1 model.",
-                     value=True
-                 )
-                 face_strength = gr.Slider(
-                     label="Face Structure strength",
-                     info="Only applied if preserve face structure is checked",
-                     value=1.3,
-                     step=0.1,
-                     minimum=0,
-                     maximum=3
-                 )
+                 preserve = gr.Checkbox(label="Preserve Face Structure", info="Higher quality, less versatility (the face structure of your first photo will be preserved). Unchecking this will use the v1 model.", value=True)
+                 face_strength = gr.Slider(label="Face Structure strength", info="Only applied if preserve face structure is checked", value=1.3, step=0.1, minimum=0, maximum=3)
                likeness_strength = gr.Slider(label="Face Embed strength", value=1.0, step=0.1, minimum=0, maximum=5)
-                 nfaa_negative_prompts = gr.Textbox(
-                     label="Appended Negative Prompts",
-                     info="Negative prompts to steer generations towards safe for all audiences outputs",
-                     value="naked, bikini, skimpy, scanty, bare skin, lingerie, swimsuit, exposed, see-through"
-                 )
+                 nfaa_negative_prompts = gr.Textbox(label="Appended Negative Prompts", info="Negative prompts to steer generations towards safe for all audiences outputs", value="naked, bikini, skimpy, scanty, bare skin, lingerie, swimsuit, exposed, see-through")
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")
    style.change(fn=change_style,
-                  inputs=style,
-                  outputs=[preserve, face_strength, likeness_strength])
+                  inputs=style,
+                  outputs=[preserve, face_strength, likeness_strength])
    files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
    remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
-     submit.click(
-         fn=generate_image,
-         inputs=[files, prompt, negative_prompt, preserve, face_strength, likeness_strength, nfaa_negative_prompts, model_choice],
-         outputs=gallery
-     )
+     submit.click(fn=generate_image,
+                  inputs=[files,prompt,negative_prompt,preserve, face_strength, likeness_strength, nfaa_negative_prompts, model_choice],
+                  outputs=gallery)

    gr.Markdown("")
-
- demo.launch()
+
+ demo.launch()
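A note on the SDXL loading this commit introduces: the new module-level code calls snapshot_download from huggingface_hub before building each StableDiffusionXLPipeline. The sketch below shows that step in isolation for a single checkpoint; it is a minimal illustration rather than the app's exact code, and it assumes snapshot_download is imported explicitly and that a CUDA device is available.

# Minimal sketch of the SDXL checkpoint loading used above (one model shown).
# Assumptions: snapshot_download is imported explicitly here, and "cuda" is available.
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import snapshot_download

ckpt_dir_pony = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")  # downloads the repo, returns a local directory
pipe_pony = StableDiffusionXLPipeline.from_pretrained(ckpt_dir_pony, torch_dtype=torch.float16)
pipe_pony.to("cuda")

Loading all three checkpoints at startup and keeping them on the GPU, as the commit does, makes switching models in the UI fast, at the cost of holding three SDXL pipelines in memory at once.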
 
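For reference, the face-embedding half of generate_image can be exercised on its own. A minimal sketch, assuming a couple of local photos (the file names below are placeholders, not files from this repo) and the same buffalo_l detector the app prepares:

# Sketch: build an averaged FaceID embedding the way generate_image does.
# "face1.jpg" / "face2.jpg" are hypothetical paths used only for illustration.
import cv2
import torch
from insightface.app import FaceAnalysis

face_app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
face_app.prepare(ctx_id=0, det_size=(640, 640))

embeds = []
for path in ["face1.jpg", "face2.jpg"]:
    img = cv2.imread(path)      # BGR image, as insightface expects
    faces = face_app.get(img)   # detection + recognition; assumes at least one face is found
    embeds.append(torch.from_numpy(faces[0].normed_embedding).unsqueeze(0))

average_embedding = torch.mean(torch.stack(embeds, dim=0), dim=0)  # the tensor passed as faceid_embeds

This averaged embedding is what the app hands to ip_model.generate or ip_model_plus.generate as faceid_embeds, together with the prompt and the pipeline selected from the Model Choice dropdown.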