JackAILab commited on
Commit
3d0b447
1 Parent(s): 8eb54f6

Update pipline_StableDiffusion_ConsistentID.py

Browse files
pipline_StableDiffusion_ConsistentID.py CHANGED
@@ -17,11 +17,6 @@ from functions import process_text_with_markers, masks_for_unique_values, fetch_
17
  from functions import ProjPlusModel, masks_for_unique_values
18
  from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
19
  from huggingface_hub import hf_hub_download
20
- ### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
21
- ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
22
- ### Thanks for the open source of face-parsing model.
23
- from models.BiSeNet.model import BiSeNet
24
- bise_net_cp_path = hf_hub_download(repo_id="JackAILab/ConsistentID", filename="face_parsing.pth", repo_type="model")
25
 
26
  PipelineImageInput = Union[
27
  PIL.Image.Image,
@@ -36,6 +31,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
36
  def load_ConsistentID_model(
37
  self,
38
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
 
39
  weight_name: str,
40
  subfolder: str = '',
41
  trigger_word_ID: str = '<|image|>',
@@ -63,10 +59,11 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
63
  self.app.prepare(ctx_id=0, det_size=(640, 640))
64
 
65
  ### BiSeNet
66
- self.bise_net = BiSeNet(n_classes = 19)
67
  # self.bise_net.cuda() # CUDA must not be initialized in the main process on Spaces with Stateless GPU environment
68
- self.bise_net_cp=bise_net_cp_path
69
- self.bise_net.load_state_dict(torch.load(self.bise_net_cp), map_location=torch.device('cpu'))
 
70
  self.bise_net.eval()
71
  # Colors for all 20 parts
72
  self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
 
17
  from functions import ProjPlusModel, masks_for_unique_values
18
  from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
19
  from huggingface_hub import hf_hub_download
 
 
 
 
 
20
 
21
  PipelineImageInput = Union[
22
  PIL.Image.Image,
 
31
  def load_ConsistentID_model(
32
  self,
33
  pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
34
+ bise_net,
35
  weight_name: str,
36
  subfolder: str = '',
37
  trigger_word_ID: str = '<|image|>',
 
59
  self.app.prepare(ctx_id=0, det_size=(640, 640))
60
 
61
  ### BiSeNet
62
+ # self.bise_net = BiSeNet(n_classes = 19)
63
  # self.bise_net.cuda() # CUDA must not be initialized in the main process on Spaces with Stateless GPU environment
64
+ # self.bise_net_cp=bise_net_cp_path
65
+ # self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
66
+ self.bise_net = bise_net # load from outside
67
  self.bise_net.eval()
68
  # Colors for all 20 parts
69
  self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],