|
import base64 |
|
import json |
|
import sys |
|
from collections import defaultdict |
|
from io import BytesIO |
|
from pprint import pprint |
|
from typing import Any, Dict, List |
|
|
|
import torch |
|
from diffusers import ( |
|
DiffusionPipeline, |
|
DPMSolverMultistepScheduler, |
|
DPMSolverSinglestepScheduler, |
|
EulerAncestralDiscreteScheduler, |
|
) |
|
from safetensors.torch import load_file |
|
from torch import autocast |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Name of the compute backend: "cuda" when PyTorch reports a usable GPU,
# otherwise "cpu".
_device_name = "cuda" if torch.cuda.is_available() else "cpu"

# This module refuses to load without a GPU, so fail at import time.
if _device_name != "cuda":
    raise ValueError("need to run on GPU")

# Module-wide torch device used by the handler below.
device = torch.device(_device_name)
|
|
|
|
|
class EndpointHandler:
    """Text-to-image request handler built on a diffusers pipeline.

    Construction loads the pipeline found at ``path`` (using the
    ``lpw_stable_diffusion`` custom pipeline in float16), replaces its
    scheduler, disables the safety checker, loads the textual-inversion
    embeddings listed in ``TEXTUAL_INVERSION`` and merges a fixed selection of
    LoRA weights into the model. Calling the instance with a request mapping
    renders a single image and returns it as a base64-encoded JPEG.
    """

    # LoRA name -> path of the ``.safetensors`` file passed to ``load_file``.
    LORA_PATHS = {
        "hairdetailer": "isatis/kw/lora/hairdetailer.safetensors",
        "lora_leica": "isatis/kw/lora/lora_leica.safetensors",
        "epiNoiseoffset_v2": "isatis/kw/lora/epiNoiseoffset_v2.safetensors",
        "MBHU-TT2FRS": "isatis/kw/lora/MBHU-TT2FRS.safetensors",
        "ShinyOiledSkin_v20": "isatis/kw/lora/ShinyOiledSkin_v20-LoRA.safetensors",
        "polyhedron_new_skin_v1.1": "isatis/kw/lora/polyhedron_new_skin_v1.1.safetensors",
        "detailed_eye-10": "isatis/kw/lora/detailed_eye-10.safetensors",
        "add_detail": "isatis/kw/lora/add_detail.safetensors",
        "MuscleGirl_v1": "isatis/kw/lora/MuscleGirl_v1.safetensors",
    }

    # Embeddings registered on the pipeline by ``load_embeddings``.
    # NOTE(review): the first entry is a full URL while every other entry is a
    # repo-relative path -- confirm that ``load_textual_inversion`` accepts it
    # as a ``weight_name``.
    TEXTUAL_INVERSION = [
        {
            "weight_name": "https://huggingface.co/isatis/kw/embeddings/EasyNegative.safetensors",
            "token": "easynegative",
        },
        {"weight_name": "isatis/kw/embeddings/badhandv4.pt", "token": "badhandv4"},
        {
            "weight_name": "isatis/kw/embeddings/bad-artist-anime.pt",
            "token": "bad-artist-anime",
        },
        {"weight_name": "isatis/kw/embeddings/NegfeetV2.pt", "token": "NegfeetV2"},
        {
            "weight_name": "isatis/kw/embeddings/ng_deepnegative_v1_75t.pt",
            "token": "ng_deepnegative_v1_75t",
        },
        {"weight_name": "isatis/kw/embeddings/bad-hands-5.pt", "token": "bad-hands-5"},
    ]

    # Embedding tokens always appended to the caller's negative prompt.
    FORCED_NEGATIVE_PROMPT = (
        """easynegative, badhandv4, bad-artist-anime, NegfeetV2, ng_deepnegative_v1_75t, bad-hands-5 """
    )

    # Keys every request must contain, in the order they are reported.
    REQUIRED_FIELDS = (
        "prompt",
        "negative_prompt",
        "width",
        "num_inference_steps",
        "height",
        "seed",
        "guidance_scale",
    )

    # Key prefixes stripped from LoRA state-dict entries before layer lookup.
    _LORA_PREFIX_UNET = "lora_unet"
    _LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(self, path="."):
        """Build and configure the pipeline.

        Args:
            path: Location handed to ``DiffusionPipeline.from_pretrained``.
        """
        self.pipe = DiffusionPipeline.from_pretrained(
            path,
            custom_pipeline="lpw_stable_diffusion",
            torch_dtype=torch.float16,
        )
        self.pipe = self.pipe.to(device)

        # Swap the scheduler, reusing the config of the one that was loaded.
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config,
            use_karras_sigmas=True,
            algorithm_type="sde-dpmsolver++",
        )

        # Run without the safety checker.
        self.pipe.safety_checker = None

        self.load_embeddings()

        # LoRAs that are always merged, with their blend weights.
        self.pipe = self.load_selected_loras(
            [
                ("polyhedron_new_skin_v1.1", 0.35),
                ("detailed_eye-10", 0.3),
                ("add_detail", 0.4),
                ("MuscleGirl_v1", 0.3),
            ],
        )

        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe.enable_attention_slicing()

    @staticmethod
    def _resolve_lora_layer(root, layer_infos, key):
        """Find the submodule of ``root`` named by underscore-split parts.

        Module attribute names may themselves contain underscores, so parts
        are joined greedily: when a candidate name is not found, the next part
        is appended with ``_`` and the lookup is retried.

        Args:
            root: Module to start from (text encoder or UNet).
            layer_infos: Name fragments obtained by splitting the key on ``_``.
            key: Original state-dict key, used only for the error message.

        Returns:
            The resolved submodule.

        Raises:
            ValueError: If the fragments run out before a module is found.
        """
        parts = list(layer_infos)
        curr_layer = root
        temp_name = parts.pop(0)
        while True:
            try:
                # __getattr__ is called directly on purpose: on torch modules
                # it only searches registered parameters, buffers and
                # submodules, whereas getattr() would also match methods such
                # as ``.to`` and break names like ``to_q``.
                curr_layer = curr_layer.__getattr__(temp_name)
            except AttributeError:
                if not parts:
                    raise ValueError(
                        f"Cannot resolve LoRA key {key!r}: no layer named {temp_name!r}"
                    ) from None
                next_part = parts.pop(0)
                temp_name = f"{temp_name}_{next_part}" if temp_name else next_part
                continue
            if not parts:
                return curr_layer
            temp_name = parts.pop(0)

    @staticmethod
    def _merge_lora_pair(layer, weight_up, weight_down, alpha):
        """Add ``alpha * (weight_up @ weight_down)`` to ``layer.weight`` in place.

        Four-dimensional weights have their two trailing axes squeezed before
        the product and restored afterwards. The product is computed in
        float32.
        """
        if len(weight_up.shape) == 4:
            up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
            down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
            layer.weight.data += alpha * torch.mm(up, down).unsqueeze(2).unsqueeze(3)
        else:
            up = weight_up.to(torch.float32)
            down = weight_down.to(torch.float32)
            layer.weight.data += alpha * torch.mm(up, down)

    def load_lora(self, pipeline, lora_path, lora_weight=0.5):
        """Merge the LoRA stored at ``lora_path`` into ``pipeline``'s weights.

        Args:
            pipeline: Object exposing ``text_encoder`` and ``unet`` modules.
            lora_path: Path of the ``.safetensors`` file to load.
            lora_weight: Scale applied to each up/down product before it is
                added to the target layer.

        Returns:
            The same ``pipeline`` object, modified in place.

        Raises:
            ValueError: If a key does not map to a layer of the pipeline.
        """
        state_dict = {
            name: tensor.to(device) for name, tensor in load_file(lora_path).items()
        }

        # Keys already merged as part of an up/down pair.
        visited = set()

        for key in state_dict:
            # ".alpha" entries are skipped and not used for scaling.
            if ".alpha" in key or key in visited:
                continue

            if "text" in key:
                prefix = EndpointHandler._LORA_PREFIX_TEXT_ENCODER
                root = pipeline.text_encoder
            else:
                prefix = EndpointHandler._LORA_PREFIX_UNET
                root = pipeline.unet

            layer_infos = key.split(".")[0].split(prefix + "_")[-1].split("_")
            curr_layer = EndpointHandler._resolve_lora_layer(root, layer_infos, key)

            # Whichever half of the pair we met first, derive the other one.
            if "lora_down" in key:
                up_key, down_key = key.replace("lora_down", "lora_up"), key
            else:
                up_key, down_key = key, key.replace("lora_up", "lora_down")

            EndpointHandler._merge_lora_pair(
                curr_layer, state_dict[up_key], state_dict[down_key], lora_weight
            )
            visited.update((up_key, down_key))

        return pipeline

    def load_embeddings(self):
        """Load textual inversions, avoid bad prompts"""
        for model in EndpointHandler.TEXTUAL_INVERSION:
            self.pipe.load_textual_inversion(
                ".", weight_name=model["weight_name"], token=model["token"]
            )

    def load_selected_loras(self, selections):
        """Load Loras models, can lead to marvelous creations

        Args:
            selections: Iterable of ``(name, weight)`` pairs; each name must be
                a key of ``LORA_PATHS``.

        Returns:
            The pipeline with every selected LoRA merged in.

        Raises:
            KeyError: If a name is not in ``LORA_PATHS``. All names are checked
                before any weights are touched, so a bad request does not leave
                the pipeline partially modified.
        """
        resolved = []
        for model_name, weight in selections:
            try:
                lora_path = EndpointHandler.LORA_PATHS[model_name]
            except KeyError:
                known = ", ".join(sorted(EndpointHandler.LORA_PATHS))
                raise KeyError(
                    f"Unknown LoRA model {model_name!r}; available: {known}"
                ) from None
            resolved.append((lora_path, weight))

        for lora_path, weight in resolved:
            self.pipe = self.load_lora(
                pipeline=self.pipe, lora_path=lora_path, lora_weight=weight
            )
        return self.pipe

    @staticmethod
    def _build_negative_prompt(negative_prompt):
        """Return the caller's negative prompt followed by the forced tokens.

        The two parts are joined with ``", "`` so the caller's last term and
        the first forced token do not run together; a trailing comma or
        whitespace on the caller's prompt is dropped first. An empty or
        ``None`` prompt yields just the forced tokens.
        """
        user_part = (negative_prompt or "").strip().rstrip(",").rstrip()
        if not user_part:
            return EndpointHandler.FORCED_NEGATIVE_PROMPT
        return f"{user_part}, {EndpointHandler.FORCED_NEGATIVE_PROMPT}"

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Render one image for the request in ``data``.

        Args:
            data: Mapping that must contain every key in ``REQUIRED_FIELDS``.
                An optional ``"loras_model"`` entry holds ``(name, weight)``
                pairs to merge before generation. A ``seed`` of ``None``
                requests unseeded generation; any other value (including
                ``0``) seeds the generator. ``data`` is not modified.

        Returns:
            ``{"flag": "success", "image": <base64 JPEG string>}`` on success,
            otherwise ``{"flag": "error", "message": <description>}``.
        """
        missing_fields = [
            field for field in EndpointHandler.REQUIRED_FIELDS if field not in data
        ]
        if missing_fields:
            return {
                "flag": "error",
                "message": f"Missing fields: {', '.join(missing_fields)}",
            }

        # Everything past validation reports failures through the error
        # response instead of letting the exception escape the handler.
        try:
            forced_negative = EndpointHandler._build_negative_prompt(
                data["negative_prompt"]
            )

            seed = data["seed"]
            generator = (
                None
                if seed is None
                else torch.Generator(device="cuda").manual_seed(seed)
            )

            # NOTE(review): load_lora adds to the layer weights in place, so
            # LoRAs requested here stay merged for later requests and are
            # added again if requested again -- confirm this is intended.
            loras_model = data.get("loras_model")
            if loras_model:
                self.pipe = self.load_selected_loras(loras_model)

            with autocast(device.type):
                image = self.pipe.text2img(
                    prompt=data["prompt"],
                    guidance_scale=data["guidance_scale"],
                    num_inference_steps=data["num_inference_steps"],
                    height=data["height"],
                    width=data["width"],
                    negative_prompt=forced_negative,
                    generator=generator,
                    max_embeddings_multiples=5,
                ).images[0]

            # Serialise the image to JPEG in memory and base64-encode it.
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue())

            return {"flag": "success", "image": img_str.decode()}

        except Exception as e:
            return {"flag": "error", "message": str(e)}
|
|