yuki-imajuku commited on
Commit
59a5c16
·
verified ·
1 Parent(s): 7ef3ec7

Update evo_nishikie_v1.py

Browse files

Adding `EvoNishikieConditioningImageProcessor` class

Files changed (1) hide show
  1. evo_nishikie_v1.py +20 -7
evo_nishikie_v1.py CHANGED
@@ -1,7 +1,6 @@
1
  import gc
2
- from io import BytesIO
3
  import os
4
- from typing import Dict, List, Union
5
 
6
  from PIL import Image, ImageFilter
7
  from controlnet_aux import LineartDetector
@@ -11,7 +10,6 @@ from diffusers import (
11
  UNet2DConditionModel,
12
  )
13
  from huggingface_hub import hf_hub_download
14
- import requests
15
  import safetensors
16
  import torch
17
  from tqdm import tqdm
@@ -30,8 +28,17 @@ UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"
30
  # Evo-Nishikie
31
  NISHIKIE_REPO = "SakanaAI/Evo-Nishikie-v1"
32
 
33
- # Threshold for image binarization
34
- BINARY_THRESHOLD = 40
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
@@ -118,7 +125,9 @@ def split_conv_attn(weights):
118
  return {"conv": conv_tensors, "attn": attn_tensors}
119
 
120
 
121
- def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
 
 
122
  # Load base models
123
  sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
124
  dpo_weights = split_conv_attn(
@@ -190,4 +199,8 @@ def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
190
  pipe.fuse_lora(lora_scale=1.0)
191
 
192
  pipe = pipe.to(device, dtype=torch.float16)
193
- return pipe
 
 
 
 
 
1
  import gc
 
2
  import os
3
+ from typing import Dict, List, Tuple, Union
4
 
5
  from PIL import Image, ImageFilter
6
  from controlnet_aux import LineartDetector
 
10
  UNet2DConditionModel,
11
  )
12
  from huggingface_hub import hf_hub_download
 
13
  import safetensors
14
  import torch
15
  from tqdm import tqdm
 
28
  # Evo-Nishikie
29
  NISHIKIE_REPO = "SakanaAI/Evo-Nishikie-v1"
30
 
31
+
32
+ class EvoNishikieConditioningImageProcessor:
33
+ def __init__(self, device="cpu"):
34
+ self.lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators").to(device)
35
+ self.image_filter = ImageFilter.MedianFilter(size=3)
36
+
37
+ def __call__(self, original_image: Image.Image) -> Image.Image:
38
+ lineart_image = self.lineart_detector(original_image, coarse=False, image_resolution=1024)
39
+ lineart_image_filtered = lineart_image.filter(self.image_filter)
40
+ conditioning_image = lineart_image_filtered.point(lambda p: 255 if p > 40 else 0).convert("L")
41
+ return conditioning_image
42
 
43
 
44
  def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
 
125
  return {"conv": conv_tensors, "attn": attn_tensors}
126
 
127
 
128
+ def load_evo_nishikie(device="cuda", processor_device="cpu") -> Tuple[
129
+ StableDiffusionXLControlNetPipeline, EvoNishikieConditioningImageProcessor
130
+ ]:
131
  # Load base models
132
  sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
133
  dpo_weights = split_conv_attn(
 
199
  pipe.fuse_lora(lora_scale=1.0)
200
 
201
  pipe = pipe.to(device, dtype=torch.float16)
202
+
203
+ # Load conditioning image processor
204
+ processor = EvoNishikieConditioningImageProcessor(device=processor_device)
205
+
206
+ return pipe, processor