Spaces:
Build error
Build error
import os | |
import random | |
import websocket | |
import uuid | |
import json | |
import urllib.request | |
import urllib.parse | |
import gradio as gr | |
from glob import glob | |
import requests | |
from pathlib import Path | |
import base64 | |
from PIL import Image | |
import time | |
import io | |
# "host:port" of the generation server. Defaults to the local instance; set the
# COMFYUI_SERVER_ADDRESS environment variable to point at another machine.
server_address = os.environ.get("COMFYUI_SERVER_ADDRESS", "127.0.0.1:8188")
# Random per-process id: sent with every queued prompt (queue_prompt) and in the
# websocket URL (generate) so both refer to the same client.
client_id = str(uuid.uuid4())
def queue_prompt(prompt):
    """Submit a workflow to the server's ``/prompt`` endpoint.

    Args:
        prompt: Workflow graph as a JSON-serialisable dict (as built by
            load_template).

    Returns:
        The decoded JSON response. get_images reads its ``'prompt_id'`` key.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    req = urllib.request.Request(
        "http://{}/prompt".format(server_address),
        data=payload,
        # Declare the body as JSON; urllib would otherwise label it as
        # application/x-www-form-urlencoded.
        headers={"Content-Type": "application/json"},
    )
    # Close the HTTP response deterministically instead of leaking it until GC.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """Download one output image from the server's ``/view`` endpoint.

    Args:
        filename: Image file name as reported in the prompt history.
        subfolder: Sub-folder reported alongside the file name.
        folder_type: Value sent as the ``type`` query parameter.

    Returns:
        bytes: the raw, undecoded image payload.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    endpoint = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(endpoint) as resp:
        payload = resp.read()
    return payload
def get_history(prompt_id):
    """Fetch and decode the server's ``/history/<prompt_id>`` record.

    Returns the parsed JSON. get_images indexes the result by *prompt_id*
    and reads its ``'outputs'`` mapping.
    """
    endpoint = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(endpoint) as resp:
        body = resp.read()
    return json.loads(body)
def get_images(ws, prompt):
    """Queue *prompt*, wait on *ws* until it finishes, then download its images.

    Args:
        ws: An already-connected websocket. Text frames are parsed as JSON
            status messages; binary frames (previews) are ignored.
        prompt: Workflow graph passed to queue_prompt.

    Returns:
        dict mapping output node id -> list of raw image bytes, containing only
        the nodes whose history output has an ``'images'`` entry.
    """
    prompt_id = queue_prompt(prompt)['prompt_id']

    # Block until the server reports that *our* prompt is done: an 'executing'
    # message whose node is None and whose prompt_id matches.
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue  # previews are binary data
        message = json.loads(out)
        if message['type'] == 'executing':
            data = message['data']
            if data['node'] is None and data['prompt_id'] == prompt_id:
                break  # execution is done

    history = get_history(prompt_id)[prompt_id]
    output_images = {}
    # Walk the outputs exactly once so each image is downloaded a single time.
    for node_id, node_output in history['outputs'].items():
        if 'images' not in node_output:
            continue
        output_images[node_id] = [
            get_image(image['filename'], image['subfolder'], image['type'])
            for image in node_output['images']
        ]
    return output_images
def detect(image, timeout=60):
    """POST an image file to the remote head-detection service and return its answer.

    Args:
        image: Path of the image file to send.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled service cannot hang the caller forever.

    Returns:
        Element ``[1]`` of the response's ``'data'`` list. Its format is defined
        by the service (presumably box coordinates usable by clip_save; verify).

    Raises:
        requests.HTTPError: If the service answers with an error status.
    """
    encoded = base64.b64encode(Path(image).read_bytes()).decode('utf-8')
    rsp = requests.post(
        'http://cv.bytedance.net/aipet_head_det/run/predict',
        # NOTE(review): the payload is always labelled image/png, whatever the
        # actual file format is.
        json={'data': ['data:image/png;base64,' + encoded]},
        timeout=timeout,
    )
    # Fail with a clear HTTP error instead of an opaque KeyError on the body.
    rsp.raise_for_status()
    return rsp.json()['data'][1]
def clip_save(img_in, coords, path="img.png"):
    """Crop an image to a box and save the crop.

    Args:
        img_in: Anything PIL.Image.open accepts (file path or file object).
        coords: Indexable whose items 0..3 are left, upper, right, lower; each
            is converted with int(). Extra items are ignored.
        path: Destination file for the cropped image.
    """
    box = tuple(int(coords[i]) for i in range(4))
    # Open the source as a context manager so its file handle is released
    # promptly (on Windows an open handle blocks overwriting/deleting it).
    with Image.open(img_in) as img:
        cropped = img.crop(box)
        cropped.load()  # make sure pixel data is in memory before the source closes
    cropped.save(path)
def load_template(img_in, seed, workflow_path=None):
    """Load the workflow JSON and fill in the input image and seed.

    Args:
        img_in: Value written to ``template["14"]["inputs"]["image"]``
            (the input image path coming from the UI).
        seed: Anything int() accepts. Values > 0 are used as-is; anything else
            (e.g. -1 or 0) is replaced by a random integer in [1, 10**8].
        workflow_path: Workflow JSON file to load. Defaults to the module-level
            ``workflow_base``, which is assigned in the ``__main__`` block.

    Returns:
        dict: the parsed workflow with those two fields replaced.
    """
    seed = int(seed)
    if workflow_path is None:
        workflow_path = workflow_base
    with open(workflow_path, encoding='utf-8') as file:
        template = json.load(file)
    template["14"]["inputs"]["image"] = img_in
    # randint needs int bounds: a float such as 1e8 is deprecated since
    # Python 3.10 and raises TypeError from Python 3.12 on.
    template["3"]["inputs"]["seed"] = seed if seed > 0 else random.randint(1, 10**8)
    return template
def generate(img_in, seed):
    """Run the workflow on *img_in*, cache every output image, and return the last one.

    Args:
        img_in: Input image path (the Gradio Image input uses type='filepath').
        seed: Seed from the UI textbox. load_template replaces values <= 0
            with a random seed.

    Returns:
        PIL.Image.Image: the last output image the workflow produced.

    Raises:
        RuntimeError: If the workflow returned no images at all.
    """
    seed = int(seed)
    template = load_template(img_in, seed)
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    try:
        images = get_images(ws, template)
    finally:
        # Always release the websocket, even if waiting/downloading fails.
        ws.close()

    os.makedirs(dir_cache, exist_ok=True)
    # File name: <unix seconds>_<seed actually used>.png. Later images from the
    # same run get _1, _2, ... so they don't overwrite each other.
    stem = "{}_{}".format(int(time.time()), template["3"]["inputs"]["seed"])
    image = None
    count = 0
    for node_id in images:
        for image_data in images[node_id]:
            image = Image.open(io.BytesIO(image_data))
            suffix = "" if count == 0 else "_{}".format(count)
            image.save(os.path.join(dir_cache, stem + suffix + ".png"))
            count += 1
    if image is None:
        raise RuntimeError("The workflow returned no images.")
    return image
if __name__ == '__main__':
    # Machine-specific paths; adjust for your environment.
    workflow_base = "D:/faceID/workflow_api_anime_0306.json"  # read by load_template()
    dir_cache = "D:/faceID/cache"  # generate() saves every output image here
    os.makedirs(dir_cache, exist_ok=True)
    demo = gr.Interface(
        fn=generate,
        inputs=[
            gr.Image(type='filepath'),
            # Seed box. The label reads "random seed"; the info text says "-1 means
            # a random seed, values > 0 are a custom seed".
            gr.Textbox(label="随机种子", value="-1", info="-1为随机种子,大于0时为自定义种子"),
        ],
        outputs=["image"],
    )
    demo.queue(max_size=2)
    # NOTE(review): share=True publishes a public Gradio share link, which exposes
    # this app (and the services it calls) beyond the local network. Confirm
    # this is intended.
    demo.launch(share=True)