sd / app.py
zhengqilin
add init
1976a91
raw
history blame
5.12 kB
# inference handler for lightning ai
import re
import os
import logging
# import json
from pydantic import BaseModel
from typing import Any, Dict, Optional, TYPE_CHECKING
from dataclasses import dataclass
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
import lightning as L
from lightning.app.components.serve import PythonServer, Text
from lightning.app import BuildConfig
class _DefaultInputData(BaseModel):
    # Request schema used as the server's default `input_type`.
    # (Kept as `#` comments rather than a docstring so the generated
    # schema description is not altered.)

    # Text prompt supplied by the caller. `PyTorchServer.predict` uses it
    # in place of its built-in prompt whenever it is non-empty.
    prompt: str
class _DefaultOutputData(BaseModel):
    # Response schema used as the server's default `output_type`.
    # (Kept as `#` comments rather than a docstring so the generated
    # schema description is not altered.)

    # First generated image, base64-encoded and decoded to a UTF-8 string
    # by `PyTorchServer.predict`.
    img_data: str

    # Value of the image's `info["parameters"]` entry, or "" when absent.
    parameters: str
@dataclass
class CustomBuildConfig(BuildConfig):
    """Build configuration that downloads model weights and installs runtime deps."""

    def build_commands(self):
        """Return the shell commands that prepare the machine image.

        The list contains, in order: ``wget`` downloads of the Stable
        Diffusion checkpoint, the VAE and two LoRA files into
        ``/content/models/<kind>``; pinned CUDA 11.7 builds of ``torch`` and
        ``torchvision``; and the system libraries whose absence triggers the
        libGL import error referenced below.

        Returns:
            A list of shell command strings.
        """
        dir_path = "/content/"

        # (target sub-directory under dir_path, source URL)
        weight_files = (
            (
                "models/Stable-diffusion",
                "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models/Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors",
            ),
            (
                "models/VAE",
                "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models/VAE/vae-ft-mse-840000-ema-pruned.ckpt",
            ),
            (
                "models/Lora",
                "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models/Lora/koreanDollLikeness_v10.safetensors",
            ),
            (
                "models/Lora",
                "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models/Lora/taiwanDollLikeness_v10.safetensors",
            ),
        )
        download_cmds = [
            "wget -P {} {}".format(os.path.join(dir_path, sub_dir), url)
            for sub_dir, url in weight_files
        ]

        pip_cmds = [
            "pip3 install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117",
            "pip3 install torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117",
        ]

        # System libraries needed to avoid "ImportError: libGL.so.1":
        # https://stackoverflow.com/questions/55313610/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directo
        # `-y` answers apt-get's confirmation prompt automatically; without it
        # the install stops and waits for input nobody can provide in a build.
        apt_cmds = [
            "sudo apt-get update",
            "sudo apt-get install -y libgl1-mesa-glx",
            "sudo apt-get install -y libglib2.0-0",
        ]

        return download_cmds + pip_cmds + apt_cmds
class PyTorchServer(PythonServer):
    """``PythonServer`` that answers requests with a Stable Diffusion txt2img image."""

    def __init__(
        self,
        input_type: type = _DefaultInputData,
        output_type: type = _DefaultOutputData,
        **kwargs: Any,
    ):
        """Create the server and attach the custom cloud build configuration.

        Args:
            input_type: Request model class forwarded to ``PythonServer``.
            output_type: Response model class forwarded to ``PythonServer``.
            **kwargs: Any further options, forwarded to ``PythonServer`` as-is.
        """
        super().__init__(
            input_type=input_type,
            output_type=output_type,
            **kwargs,
        )
        # Swap in the build config that downloads the weights and installs deps.
        self.cloud_build_config = CustomBuildConfig()

    def setup(self):
        """Normalise the torch version string, then run the handler's ``initialize``."""
        # Imported here rather than at module level: the dependencies have to
        # be installed before these packages can be imported.
        import torch

        # Truncate the version number of nightly/local PyTorch builds so it
        # does not cause exceptions with CodeFormer or Safetensors.
        reported_version = torch.__version__
        if ".dev" in reported_version or "+git" in reported_version:
            torch.__long_version__ = reported_version
            torch.__version__ = re.search(r'[\d.]+[\d]', reported_version).group(0)

        from handler import initialize

        initialize()

    def predict(self, request):
        """Generate one image for ``request`` and return it base64-encoded.

        Args:
            request: Object exposing a ``prompt`` attribute. A truthy prompt
                replaces the built-in default prompt.

        Returns:
            A dict with ``"img_data"`` (base64 string of the first generated
            image) and ``"parameters"`` (the image's ``info["parameters"]``
            entry, or ``""`` when it is missing).
        """
        from modules.api.api import encode_pil_to_base64
        from modules import shared
        from modules.processing import StableDiffusionProcessingTxt2Img, process_images

        # NOTE(review): the default prompt names "koreanDollLikeness_v15" and
        # has no <...> around the lora tag, while CustomBuildConfig downloads
        # "koreanDollLikeness_v10" - confirm whether the LoRA is meant to apply.
        generation_args = {
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "outpath_samples": "/content/desktop",
            "prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
            "negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed",
            "sampler_name": "DPM++ SDE Karras",
            "steps": 20,  # the original author's note here reads "25"
            "cfg_scale": 8,
            "width": 512,
            "height": 768,
            "seed": -1,
        }

        print("&&&&&&&&&&&&&&&&&&&&&&&&",request)

        # A non-empty prompt from the caller overrides the default one.
        if request.prompt:
            requested_prompt = request.prompt
            print("get prompt from request: ", requested_prompt)
            generation_args["prompt"] = requested_prompt

        job = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **generation_args)
        result = process_images(job)

        # Only the first generated image is returned.
        first_image = result.images[0]
        encoded_image = encode_pil_to_base64(first_image).decode('utf-8')
        return {
            "img_data": encoded_image,
            "parameters": first_image.info.get('parameters', ""),
        }
# Launch with: lightning run app app.py --cloud
# The server component is created with a GPU cloud-compute configuration
# and wrapped in the LightningApp exposed as `app`.
component = PyTorchServer(
    cloud_compute=L.CloudCompute(
        'gpu',
        disk_size=20,
        idle_timeout=30,
    ),
)
app = L.LightningApp(component)