Spaces: Running on Zero

Bobby committed · Commit 778d937
1 Parent(s): 26a9c47

new start

Browse files:
- .gitignore +1 -0
- anime_app.py +136 -135
- anime_model.py +157 -185
- preprocess_anime.py +49 -57
- requirements.txt +4 -4
.gitignore ADDED
@@ -0,0 +1 @@
+venv/*
anime_app.py CHANGED
@@ -1,135 +1,136 @@
+import gradio as gr
+from anime_model import Model
+import spaces
+prod = False
+port = 8080
+show_options = True
+if prod:
+    port = 8081
+    show_options = False
+
+from settings import (
+    DEFAULT_IMAGE_RESOLUTION,
+    MAX_NUM_IMAGES,
+    MAX_SEED,
+)
+from utils import randomize_seed_fn
+
+base_model = "nyxia/AAM-AnyLoRA-Anime-Mix"
+model = Model(base_model_id=base_model, task_name="NormalBae")
+
+# note: for high res 1024x1024, set num steps to 9 and guidance to 6
+def auto_process_image(image, prompt):
+    a_prompt = "anime style, cartoon, drawing, 2D anime, illustration, cartoon"
+    n_prompt = "realism, 3d, BadDream, (UnrealisticDream:1.2), split image, multiple views, text, cropped, out of frame, worst quality, low quality, jpeg artifacts, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, bad proportions, extra limbs, longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
+    num_samples = 4
+    image_resolution = 512
+    preprocess_resolution = 512
+    num_steps = 15
+    guidance_scale = 4.5
+    seed = randomize_seed_fn(0, True)
+    config = [
+        image,
+        prompt,
+        a_prompt,
+        n_prompt,
+        num_samples,
+        image_resolution,
+        preprocess_resolution,
+        num_steps,
+        guidance_scale,
+        seed,
+    ]
+    if image is None:
+        return None
+    print("processing image")
+    # print(config)
+    return model.process_normal(*config)
+
+with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo:
+    with gr.Row():
+        # examples
+        gr.Text(label="Anime Style Examples", value="Weeb!")
+    with gr.Row():
+        with gr.Column():
+            # input text
+            prompt = gr.Textbox(label="Anime Style", placeholder="anime tittes")
+    with gr.Row():
+        with gr.Column():
+            # input image
+            image = gr.Image(label="Input", sources=["upload"], show_label=True, format="jpeg")
+        with gr.Column():
+            # output
+            result = gr.Gallery(label="Anime", show_label=True, columns=2, scale=3, object_fit="contain", format="jpeg")
+        with gr.Column():
+            # run button; size expects the literal "lg", not a list
+            run_button = gr.Button("Run", size="lg")
+    with gr.Row():
+        with gr.Accordion("Advanced options", open=show_options, visible=show_options):
+            num_samples = gr.Slider(
+                label="Images", minimum=1, maximum=MAX_NUM_IMAGES, value=4, step=1
+            )
+            image_resolution = gr.Slider(
+                label="Image resolution",
+                minimum=256,
+                maximum=1024,
+                value=DEFAULT_IMAGE_RESOLUTION,
+                step=256,
+            )
+            preprocess_resolution = gr.Slider(
+                label="Preprocess resolution", minimum=128, maximum=1024, value=512, step=1
+            )
+            num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=15, step=1)  # 20/4.5 or 12 without lora, 4 with lora
+            guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=4.5, step=0.1)  # 5 without lora, 2 with lora
+            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+            a_prompt = gr.Textbox(label="Additional prompt", value="anime style, cartoon, drawing, 2D anime, illustration, cartoon")
+            n_prompt = gr.Textbox(
+                label="Negative prompt",
+                # value="BadDream, (UnrealisticDream:1.2), split image, multiple views, text, cropped, out of frame, worst quality, low quality, jpeg artifacts, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, bad proportions, extra limbs, longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
+                value="(signature:1.2),(artist name:1.2),(watermark:1.2), (easynegative), (low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2),bad composition, inaccurate eyes, extra digit,fewer digits, (extra arms:1.2), badhandv4,clothes",
+            )
+    config = [
+        image,
+        prompt,
+        a_prompt,
+        n_prompt,
+        num_samples,
+        image_resolution,
+        preprocess_resolution,
+        num_steps,
+        guidance_scale,
+        seed,
+    ]
+    prompt.submit(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+        show_progress="minimal",
+    ).then(
+        fn=model.process_normal,
+        inputs=config,
+        outputs=result,
+        api_name=False,
+        show_progress="minimal",
+    )
+    run_button.click(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+        show_progress="minimal",
+    ).then(
+        fn=model.process_normal,
+        inputs=config,
+        outputs=result,
+        show_progress="minimal",
+    )
+    image.change(auto_process_image, inputs=[image, prompt], outputs=[result])
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch(server_name="localhost", server_port=port)
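anime_app.py imports DEFAULT_IMAGE_RESOLUTION, MAX_NUM_IMAGES, MAX_SEED, and randomize_seed_fn from settings and utils, neither of which is part of this commit. A minimal sketch of what those modules plausibly contain, inferred only from how the names are used above; every value and signature here is an assumption:

# settings.py -- assumed values, inferred from the slider bounds above
DEFAULT_IMAGE_RESOLUTION = 512  # default of the "Image resolution" slider
MAX_IMAGE_RESOLUTION = 1024     # bound checked in Model.process_normal
MAX_NUM_IMAGES = 4              # upper bound of the "Images" slider
MAX_SEED = 2**32 - 1            # upper bound of the "Seed" slider

# utils.py -- assumed implementation of the seed helper
import random

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    # Draw a fresh random seed when the "Randomize seed" box is ticked,
    # otherwise pass the given seed through unchanged.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed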
anime_model.py CHANGED
@@ -1,185 +1,157 @@
-        image_resolution: int,
-        preprocess_resolution: int,
-        num_steps: int,
-        guidance_scale: float,
-        seed: int,
-    ) -> list[PIL.Image.Image]:
-        if image is None:
-            raise ValueError
-        #if image_resolution > MAX_IMAGE_RESOLUTION:
-        #    raise ValueError
-        #if num_images > MAX_NUM_IMAGES:
-        #    raise ValueError
-        #self.load("NormalBae")
-        model = NormalBaeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
-        torch.cuda.empty_cache()
-        gc.collect()
-        if prompt == "":
-            prompt = "anime girl"
-        print(prompt)
-        return run_pipe(
-            prompt=self.get_prompt("Hentai Nude Anime Titties of " + prompt, additional_prompt),
-            negative_prompt=negative_prompt,
-            control_image=image,
-            num_images=num_images,
-            num_steps=num_steps,
-            guidance_scale=guidance_scale,
-            seed=seed,
-        )
+from __future__ import annotations
+
+import gc
+import time
+import numpy as np
+import PIL.Image
+import torch
+import spaces
+from diffusers import (
+    ControlNetModel,
+    AutoencoderKL,
+    StableDiffusionControlNetPipeline,
+    DPMSolverMultistepScheduler,
+)
+from preprocess_anime import Preprocessor
+from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
+
+@spaces.GPU
+class Model:
+    def __init__(self, base_model_id, task_name):
+        print("Initializing base model: ", base_model_id, " with ", task_name)
+        self.base_model_id = base_model_id
+        self.task_name = task_name
+        self.pipe = self.load_pipe(base_model_id, task_name)
+        self.preprocessor = Preprocessor()
+
+    def load_pipe(self, base_model_id, task_name):
+        print("loading pipe")
+        # Controlnet
+        model_id = "lllyasviel/control_v11p_sd15_normalbae"
+        print("initializing controlnet")
+        controlnet = ControlNetModel.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16,
+            attn_implementation="flash_attention_2",
+        ).to("cuda")
+        controlnet.to(memory_format=torch.channels_last)
+
+        # Scheduler
+        scheduler = DPMSolverMultistepScheduler.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            subfolder="scheduler",
+            use_karras_sigmas=True,
+            # final_sigmas_type="sigma_min",
+            algorithm_type="sde-dpmsolver++",
+            # prediction_type="epsilon",
+            # thresholding=False,
+            denoise_final=True,
+            device_map="cuda",
+            attn_implementation="flash_attention_2",
+        )
+
+        # vae
+        vae_url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
+        vae = AutoencoderKL.from_single_file(vae_url, torch_dtype=torch.float16).to("cuda")
+        vae.to(memory_format=torch.channels_last)
+
+        # Stable Diffusion pipeline
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            base_model_id,
+            safety_checker=None,
+            controlnet=controlnet,
+            scheduler=scheduler,
+            vae=vae,
+            torch_dtype=torch.float16,
+        ).to("cuda")
+
+        # efficiency optimizations - DO NOT CHANGE ORDER
+        pipe.enable_xformers_memory_efficient_attention()
+
+        # lora
+        # pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
+        # pipe.load_lora_weights("Lykon/AnyLoRA", weight_name="AnyLoRA_bakedVae_blessed_fp16.safetensors")
+        # pipe.load_lora_weights("Lykon/AnyLoRA", weight_name="AnyLoRA_noVae_fp16-pruned.safetensors")
+        # pipe.fuse_lora()
+        # pipe.unet.to(memory_format=torch.channels_last)
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.base_model_id = base_model_id
+        self.task_name = task_name
+
+        return pipe
+
+    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
+        if not prompt:
+            prompt = additional_prompt
+        else:
+            prompt = f"{prompt}, {additional_prompt}"
+        return prompt
+
+    @torch.inference_mode()
+    def run_pipe(
+        self,
+        prompt: str,
+        negative_prompt: str,
+        control_image: PIL.Image.Image,
+        num_images: int,
+        num_steps: int,
+        guidance_scale: float,
+        seed: int,
+    ) -> list[PIL.Image.Image]:
+        # torch.cuda.manual_seed() returns None; build an explicit Generator
+        # so the seed actually reaches the pipeline.
+        generator = torch.Generator("cuda").manual_seed(seed)
+        torch.cuda.synchronize()
+        start = time.time()
+        results = self.pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=num_images,
+            num_inference_steps=num_steps,
+            generator=generator,
+            image=control_image,
+        ).images
+        print(f"Inference done in: {time.time() - start:.2f} seconds")
+        print(f"Prompt: {prompt}")
+        torch.cuda.synchronize()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        return results
+
+    def process_normal(
+        self,
+        image: np.ndarray,
+        prompt: str,
+        additional_prompt: str,
+        negative_prompt: str,
+        num_images: int,
+        image_resolution: int,
+        preprocess_resolution: int,
+        num_steps: int,
+        guidance_scale: float,
+        seed: int,
+    ) -> list[PIL.Image.Image]:
+        if image is None:
+            raise ValueError
+        if image_resolution > MAX_IMAGE_RESOLUTION:
+            raise ValueError
+        if num_images > MAX_NUM_IMAGES:
+            raise ValueError
+        self.preprocessor.load("NormalBae")
+        control_image = self.preprocessor(
+            image=image,
+            image_resolution=image_resolution,
+            detect_resolution=preprocess_resolution,
+        )
+        if prompt == "":
+            prompt = "anime girl"
+        print(prompt)
+        return self.run_pipe(
+            prompt=self.get_prompt("Hentai Photo from imgur of " + prompt, additional_prompt),
+            negative_prompt=negative_prompt,
+            control_image=control_image,
+            num_images=num_images,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            seed=seed,
+        )
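For reference, a minimal smoke test of the Model class above, run outside Gradio. This is a sketch, not part of the commit: the file names and prompts are placeholders, and a CUDA device is assumed to be available.

import numpy as np
import PIL.Image

from anime_model import Model

model = Model(base_model_id="nyxia/AAM-AnyLoRA-Anime-Mix", task_name="NormalBae")

# Any RGB image works as the ControlNet conditioning source.
image = np.array(PIL.Image.open("input.jpg").convert("RGB"))

results = model.process_normal(
    image=image,
    prompt="anime girl",
    additional_prompt="anime style, cartoon, drawing, 2D anime",
    negative_prompt="low quality, worst quality",
    num_images=1,
    image_resolution=512,
    preprocess_resolution=512,
    num_steps=15,
    guidance_scale=4.5,
    seed=0,
)
results[0].save("output.jpg")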
preprocess_anime.py CHANGED
@@ -1,57 +1,49 @@
+import gc
+
+import numpy as np
+import PIL.Image
+import torch
+from controlnet_aux import NormalBaeDetector
+
+from controlnet_aux.util import HWC3
+from cv_utils import resize_image
+
+class Preprocessor:
+    MODEL_ID = "lllyasviel/Annotators"
+
+    def __init__(self):
+        self.model = None
+        self.name = ""
+
+    def load(self, name: str) -> None:
+        if name == self.name:
+            return
+        elif name == "NormalBae":
+            self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
+        else:
+            raise ValueError
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.name = name
+
+    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
+        if self.name == "Canny":
+            if "detect_resolution" in kwargs:
+                detect_resolution = kwargs.pop("detect_resolution")
+                image = np.array(image)
+                image = HWC3(image)
+                image = resize_image(image, resolution=detect_resolution)
+            image = self.model(image, **kwargs)
+            return PIL.Image.fromarray(image)
+        elif self.name == "Midas":
+            detect_resolution = kwargs.pop("detect_resolution", 512)
+            image_resolution = kwargs.pop("image_resolution", 512)
+            image = np.array(image)
+            image = HWC3(image)
+            image = resize_image(image, resolution=detect_resolution)
+            image = self.model(image, **kwargs)
+            image = HWC3(image)
+            image = resize_image(image, resolution=image_resolution)
+            return PIL.Image.fromarray(image)
+        else:
+            return self.model(image, **kwargs)
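preprocess_anime.py also depends on cv_utils.resize_image, which this commit does not touch. A sketch of the assumed behavior, modeled on the resize helper in the upstream ControlNet demo this Space appears to derive from: scale the shorter side to the requested resolution and round both dimensions to multiples of 64 so they stay valid UNet input sizes.

import cv2
import numpy as np

def resize_image(input_image: np.ndarray, resolution: int) -> np.ndarray:
    # Assumed behavior: shorter side scaled to `resolution`, height and
    # width rounded to the nearest multiple of 64.
    h, w = input_image.shape[:2]
    k = resolution / min(h, w)
    new_h = int(round(h * k / 64.0)) * 64
    new_w = int(round(w * k / 64.0)) * 64
    interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
    return cv2.resize(input_image, (new_w, new_h), interpolation=interpolation)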
requirements.txt CHANGED
@@ -1,6 +1,6 @@
-torch==2.1.2
-torchvision
-torchaudio
+#torch==2.1.2
+#torchvision
+#torchaudio
 diffusers==0.27.2
 einops==0.6.1
 gradio==4.26.0
@@ -10,6 +10,6 @@ mediapipe==0.10.1
 opencv-python-headless==4.8.0.74
 safetensors==0.4.2
 transformers==4.39.3
-xformers==0.0.23.post1
+#xformers==0.0.23.post1
 accelerate==0.29.1
 #controlnet_aux==0.0.7