|
|
|
|
|
import re |
|
import os |
|
import logging |
|
|
|
from pydantic import BaseModel |
|
from typing import Any, Dict, Optional, TYPE_CHECKING |
|
from dataclasses import dataclass |
|
# Suppress the "A matching Triton is not available" message from the
# "xformers" logger; every other record from that logger still passes.
logging.getLogger("xformers").addFilter(
    lambda log_record: "A matching Triton is not available" not in log_record.getMessage()
)
|
|
|
import lightning as L |
|
from lightning.app.components.serve import PythonServer, Text |
|
from lightning.app import BuildConfig |
|
|
|
|
|
class _DefaultInputData(BaseModel):
    """Request payload for the server: the text prompt to render."""

    prompt: str
|
|
|
class _DefaultOutputData(BaseModel):
    """Response payload returned by ``PyTorchServer.predict``.

    ``img_data`` holds the generated image as a base64 string; ``parameters``
    holds the image's generation-parameters text (empty when absent).
    """

    img_data: str
    parameters: str
|
|
|
|
|
@dataclass
class CustomBuildConfig(BuildConfig):
    """Cloud build configuration that provisions the Stable Diffusion runtime.

    The build downloads the model weights (checkpoint, VAE and two LoRA files)
    under ``/content/models``. It then installs the CUDA 11.7 wheels of
    torch/torchvision and the system libraries ``libgl1-mesa-glx`` and
    ``libglib2.0-0``.
    """

    def build_commands(self):
        """Return the ordered list of shell commands run when the image is built."""
        dir_path = "/content/"
        hf_models = "https://huggingface.co/Hardy01/chill_watcher/resolve/main/models"

        # (destination sub-directory under dir_path, source URL), in download order.
        downloads = [
            ("models/Stable-diffusion", f"{hf_models}/Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors"),
            ("models/VAE", f"{hf_models}/VAE/vae-ft-mse-840000-ema-pruned.ckpt"),
            ("models/Lora", f"{hf_models}/Lora/koreanDollLikeness_v10.safetensors"),
            ("models/Lora", f"{hf_models}/Lora/taiwanDollLikeness_v10.safetensors"),
        ]
        commands = [
            f"wget -P {os.path.join(dir_path, subdir)} {url}" for subdir, url in downloads
        ]

        torch_index = "--extra-index-url https://download.pytorch.org/whl/cu117"
        commands += [
            f"pip3 install torch==1.13.1+cu117 {torch_index}",
            f"pip3 install torchvision==0.14.1+cu117 {torch_index}",
            "sudo apt-get update",
            # -y answers apt-get's "Do you want to continue?" prompt; the build
            # is unattended, so without it apt-get aborts instead of installing.
            "sudo apt-get install -y libgl1-mesa-glx",
            "sudo apt-get install -y libglib2.0-0",
        ]
        return commands
|
|
|
|
|
class PyTorchServer(PythonServer):
    """Lightning ``PythonServer`` that generates one image per text prompt.

    Requests follow ``_DefaultInputData`` (a ``prompt``). Responses follow
    ``_DefaultOutputData``: the first generated image as base64 plus its
    parameters text. Model code (``handler``, ``modules``) is imported lazily
    in ``setup``/``predict`` so the module can be imported without it.
    """

    # Default txt2img settings. Copied per request; a non-empty request prompt
    # replaces "prompt".
    # NOTE(review): the build downloads koreanDollLikeness_v10, and the webui LoRA
    # syntax is "<lora:name:weight>". The "lora:koreanDollLikeness_v15:0.66" tag
    # below presumably has no LoRA effect — confirm and fix the prompt.
    _DEFAULT_TXT2IMG_ARGS = {
        "do_not_save_samples": True,
        "do_not_save_grid": True,
        "outpath_samples": "/content/desktop",
        "prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
        "negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed",
        "sampler_name": "DPM++ SDE Karras",
        "steps": 20,
        "cfg_scale": 8,
        "width": 512,
        "height": 768,
        "seed": -1,
    }

    def __init__(
        self,
        input_type: type = _DefaultInputData,
        output_type: type = _DefaultOutputData,
        **kwargs: Any,
    ):
        """Create the server and attach the custom cloud build configuration.

        Args:
            input_type: request model class (defaults to ``_DefaultInputData``).
            output_type: response model class (defaults to ``_DefaultOutputData``).
            **kwargs: forwarded unchanged to ``PythonServer.__init__``.
        """
        super().__init__(input_type=input_type, output_type=output_type, **kwargs)
        self.cloud_build_config = CustomBuildConfig()

    def setup(self):
        """Normalise torch's version string, then initialise the model backend."""
        import torch

        version = torch.__version__
        if ".dev" in version or "+git" in version:
            # Keep the full string in __long_version__ and expose only the leading
            # numeric part (e.g. "2.0.0" from "2.0.0.dev2023..."), presumably for
            # downstream version checks — confirm which consumer needs this.
            torch.__long_version__ = version
            match = re.search(r'[\d.]+[\d]', version)
            if match is not None:  # previously an unguarded .group(0) on None
                torch.__version__ = match.group(0)

        from handler import initialize

        initialize()

    def predict(self, request):
        """Run txt2img for ``request`` and return the first image.

        Args:
            request: object with a ``prompt`` attribute. An empty prompt falls
                back to the default prompt in ``_DEFAULT_TXT2IMG_ARGS``.

        Returns:
            dict with ``img_data`` (base64 string) and ``parameters`` (the image's
            ``info["parameters"]`` text, or ``""``).

        Raises:
            RuntimeError: if the pipeline returned no images.
        """
        from modules.api.api import encode_pil_to_base64
        from modules import shared
        from modules.processing import StableDiffusionProcessingTxt2Img, process_images

        logger = logging.getLogger(__name__)
        logger.debug("received request: %s", request)

        args = dict(self._DEFAULT_TXT2IMG_ARGS)  # never mutate the shared defaults
        if request.prompt:
            logger.info("get prompt from request: %s", request.prompt)
            args["prompt"] = request.prompt

        processed = process_images(
            StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
        )
        if not processed.images:
            raise RuntimeError("txt2img pipeline returned no images")

        image = processed.images[0]
        return {
            "img_data": encode_pil_to_base64(image).decode('utf-8'),
            "parameters": image.info.get('parameters', ""),
        }
|
|
|
|
|
# Entry point: one PyTorchServer on a "gpu" cloud machine, wrapped in a
# LightningApp. NOTE(review): disk_size/idle_timeout units (presumably GB and
# seconds) come from lightning's CloudCompute — confirm against its docs.
component = PyTorchServer(
    cloud_compute=L.CloudCompute(
        'gpu',
        disk_size=20,
        idle_timeout=30,
    ),
)

app = L.LightningApp(component)
|
|
|
|