multimodalart HF staff commited on
Commit
f5d25ef
·
verified ·
1 Parent(s): 5771eee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -12
app.py CHANGED
@@ -1,16 +1,32 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionXLPipeline, AutoencoderKL
4
  from huggingface_hub import hf_hub_download
5
  from safetensors.torch import load_file
6
  from share_btn import community_icon_html, loading_icon_html, share_js
7
  from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
 
8
  import lora
9
  import copy
10
  import json
11
  import gc
12
  import random
13
  from urllib.parse import quote
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  with open("sdxl_loras.json", "r") as file:
15
  data = json.load(file)
16
  sdxl_loras_raw = [
@@ -52,16 +68,60 @@ sdxl_loras_raw_new = [item for item in sdxl_loras_raw if item.get("new") == True
52
 
53
  sdxl_loras_raw = [item for item in sdxl_loras_raw if item.get("new") != True]
54
 
55
- lcm_lora_id = "lcm-sd/lcm-sdxl-base-1.0-lora"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  vae = AutoencoderKL.from_pretrained(
58
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
59
  )
60
- pipe = StableDiffusionXLPipeline.from_pretrained(
61
  "stabilityai/stable-diffusion-xl-base-1.0",
62
  vae=vae,
63
  torch_dtype=torch.float16,
64
  )
 
65
  original_pipe = copy.deepcopy(pipe)
66
  pipe.to(device)
67
 
@@ -162,9 +222,22 @@ def merge_incompatible_lora(full_path_lora, lora_scale):
162
  del lora_model
163
  gc.collect()
164
 
165
- def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_loras_new, progress=gr.Progress(track_tqdm=True)):
166
  global last_lora, last_merged, last_fused, pipe
167
- print("Index when running ", selected_state.index)
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  if(selected_state.index < 0):
169
  if(selected_state.index == -9999):
170
  selected_state.index = 0
@@ -194,19 +267,22 @@ def run_lora(prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_lora
194
  if(is_pivotal):
195
  #Add the textual inversion embeddings from pivotal tuning models
196
  text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
197
- text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
198
- tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
199
- embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
200
- embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
201
- embhandler.load_embeddings(embedding_path)
202
 
203
  image = pipe(
204
  prompt=prompt,
205
  negative_prompt=negative,
206
  width=1024,
207
  height=1024,
 
 
 
 
208
  num_inference_steps=20,
209
- guidance_scale=7.5,
 
210
  ).images[0]
211
  last_lora = repo_name
212
  gc.collect()
@@ -332,7 +408,7 @@ with gr.Blocks(css="custom.css") as demo:
332
  show_progress=False
333
  ).success(
334
  fn=run_lora,
335
- inputs=[prompt, negative, weight, selected_state, gr_sdxl_loras, gr_sdxl_loras_new],
336
  outputs=[result, share_group],
337
  )
338
  button.click(
 
1
  import gradio as gr
2
  import torch
 
3
  from huggingface_hub import hf_hub_download
4
  from safetensors.torch import load_file
5
  from share_btn import community_icon_html, loading_icon_html, share_js
6
  from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
7
+
8
  import lora
9
  import copy
10
  import json
11
  import gc
12
  import random
13
  from urllib.parse import quote
14
+ import gdown
15
+ import os
16
+
17
+ import diffusers
18
+ from diffusers.utils import load_image
19
+ from diffusers.models import ControlNetModel
20
+ from diffusers import AutoencoderKL, DPMSolverMultistepScheduler
21
+ import cv2
22
+ import torch
23
+ import numpy as np
24
+ from PIL import Image
25
+
26
+ from insightface.app import FaceAnalysis
27
+ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
28
+ from controlnet_aux import ZoeDetector
29
+
30
  with open("sdxl_loras.json", "r") as file:
31
  data = json.load(file)
32
  sdxl_loras_raw = [
 
68
 
69
  sdxl_loras_raw = [item for item in sdxl_loras_raw if item.get("new") != True]
70
 
71
+ # download models
72
+ hf_hub_download(
73
+ repo_id="InstantX/InstantID",
74
+ filename="ControlNetModel/config.json",
75
+ local_dir="/data/checkpoints",
76
+ )
77
+ hf_hub_download(
78
+ repo_id="InstantX/InstantID",
79
+ filename="ControlNetModel/diffusion_pytorch_model.safetensors",
80
+ local_dir="/data/checkpoints",
81
+ )
82
+ hf_hub_download(
83
+ repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="/data/checkpoints"
84
+ )
85
+ hf_hub_download(
86
+ repo_id="latent-consistency/lcm-lora-sdxl",
87
+ filename="pytorch_lora_weights.safetensors",
88
+ local_dir="/data/checkpoints",
89
+ )
90
+ # download antelopev2
91
+ gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output=".", quiet=False, fuzzy=True)
92
+ # unzip antelopev2.zip
93
+ os.system("unzip .antelopev2.zip -d /data/models/")
94
+
95
+ app = FaceAnalysis(name='antelopev2', root='./data', providers=['CPUExecutionProvider'])
96
+ app.prepare(ctx_id=0, det_size=(640, 640))
97
+
98
+ # prepare models under ./checkpoints
99
+ face_adapter = f'/data/checkpoints/ip-adapter.bin'
100
+ controlnet_path = f'/data/checkpoints/ControlNetModel'
101
+
102
+ # load IdentityNet
103
+ identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
104
+ zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0",torch_dtype=torch.float16)
105
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
106
+ pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("rubbrband/albedobaseXL_v21",
107
+ vae=vae,
108
+ controlnet=[identitynet, zoedepthnet],
109
+ torch_dtype=torch.float16)
110
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
111
+ pipe.load_ip_adapter_instantid(face_adapter)
112
+ pipe.set_ip_adapter_scale(0.8)
113
+ zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
114
+ zoe.to("cuda")
115
 
116
  vae = AutoencoderKL.from_pretrained(
117
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
118
  )
119
+ pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
120
  "stabilityai/stable-diffusion-xl-base-1.0",
121
  vae=vae,
122
  torch_dtype=torch.float16,
123
  )
124
+
125
  original_pipe = copy.deepcopy(pipe)
126
  pipe.to(device)
127
 
 
222
  del lora_model
223
  gc.collect()
224
 
225
+ def run_lora(face_image, prompt, negative, lora_scale, selected_state, sdxl_loras, sdxl_loras_new, progress=gr.Progress(track_tqdm=True)):
226
  global last_lora, last_merged, last_fused, pipe
227
+
228
+ face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
229
+ face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
230
+ face_emb = face_info['embedding']
231
+ face_kps = draw_kps(face_image, face_info['kps'])
232
+
233
+ #prepare face zoe
234
+ with torch.no_grad():
235
+ image_zoe = zoe(face_image)
236
+
237
+ width, height = face_kps.size
238
+ images = [face_kps, image_zoe.resize((height, width))]
239
+
240
+
241
  if(selected_state.index < 0):
242
  if(selected_state.index == -9999):
243
  selected_state.index = 0
 
267
  if(is_pivotal):
268
  #Add the textual inversion embeddings from pivotal tuning models
269
  text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
270
+ state_dict_embedding = load_file(text_embedding_name)
271
+ pipe.load_textual_inversion(state_dict_embedding["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
272
+ pipe.load_textual_inversion(state_dict_embedding["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
 
 
273
 
274
  image = pipe(
275
  prompt=prompt,
276
  negative_prompt=negative,
277
  width=1024,
278
  height=1024,
279
+ image_embeds=face_emb,
280
+ image=face_image,
281
+ strength=0.85,
282
+ control_image=images,
283
  num_inference_steps=20,
284
+ guidance_scale = 7,
285
+ controlnet_conditioning_scale=[0.8, 0.8],
286
  ).images[0]
287
  last_lora = repo_name
288
  gc.collect()
 
408
  show_progress=False
409
  ).success(
410
  fn=run_lora,
411
+ inputs=[photo, prompt, negative, weight, selected_state, gr_sdxl_loras, gr_sdxl_loras_new],
412
  outputs=[result, share_group],
413
  )
414
  button.click(