yuki-imajuku
committed on
Update evo_nishikie_v1.py
Adding `EvoNishikieConditioningImageProcessor` class
- evo_nishikie_v1.py +20 -7
evo_nishikie_v1.py
CHANGED
@@ -1,7 +1,6 @@
 import gc
-from io import BytesIO
 import os
-from typing import Dict, List, Union
+from typing import Dict, List, Tuple, Union

 from PIL import Image, ImageFilter
 from controlnet_aux import LineartDetector
@@ -11,7 +10,6 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from huggingface_hub import hf_hub_download
-import requests
 import safetensors
 import torch
 from tqdm import tqdm
@@ -30,8 +28,17 @@ UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"
 # Evo-Nishikie
 NISHIKIE_REPO = "SakanaAI/Evo-Nishikie-v1"

-
-
+
+class EvoNishikieConditioningImageProcessor:
+    def __init__(self, device="cpu"):
+        self.lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators").to(device)
+        self.image_filter = ImageFilter.MedianFilter(size=3)
+
+    def __call__(self, original_image: Image.Image) -> Image.Image:
+        lineart_image = self.lineart_detector(original_image, coarse=False, image_resolution=1024)
+        lineart_image_filtered = lineart_image.filter(self.image_filter)
+        conditioning_image = lineart_image_filtered.point(lambda p: 255 if p > 40 else 0).convert("L")
+        return conditioning_image


 def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
@@ -118,7 +125,9 @@ def split_conv_attn(weights):
     return {"conv": conv_tensors, "attn": attn_tensors}


-def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
+def load_evo_nishikie(device="cuda", processor_device="cpu") -> Tuple[
+    StableDiffusionXLControlNetPipeline, EvoNishikieConditioningImageProcessor
+]:
     # Load base models
     sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
     dpo_weights = split_conv_attn(
@@ -190,4 +199,8 @@ def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
     pipe.fuse_lora(lora_scale=1.0)

     pipe = pipe.to(device, dtype=torch.float16)
-    return pipe
+
+    # Load conditioning image processor
+    processor = EvoNishikieConditioningImageProcessor(device=processor_device)
+
+    return pipe, processor
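For context, a minimal usage sketch of the updated interface: `load_evo_nishikie` now returns both the ControlNet pipeline and the new conditioning image processor. The input path, prompt, and generation settings below are illustrative placeholders, not part of this commit:

```python
import torch
from PIL import Image

from evo_nishikie_v1 import load_evo_nishikie

# Pipeline on GPU, lineart detector on CPU (the function's defaults in this commit).
pipe, processor = load_evo_nishikie(device="cuda", processor_device="cpu")

# Turn any reference image into a binarized lineart conditioning image.
original_image = Image.open("reference.png").convert("RGB")  # placeholder path
conditioning_image = processor(original_image)

# Pass the conditioning image to the ControlNet pipeline (prompt/settings are placeholders).
output = pipe(
    prompt="a placeholder prompt",
    image=conditioning_image,
    num_inference_steps=40,
    generator=torch.Generator().manual_seed(0),
)
output.images[0].save("output.png")
```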