|
import base64 |
|
import io |
|
import re |
|
import time |
|
from datetime import date |
|
from pathlib import Path |
|
|
|
import gradio as gr |
|
import requests |
|
import torch |
|
from PIL import Image |
|
|
|
from modules import shared |
|
from modules.models import reload_model, unload_model |
|
from modules.ui import create_refresh_button |
|
|
|
# NOTE(review): presumably disables TorchScript's profiling executor for the
# whole process; the reason is not visible in this file -- confirm before removing.
torch._C._jit_set_profiling_mode(False)

# Runtime settings of the extension. The UI callbacks in ui() write the user's
# choices back into this dict, and the request builders read from it.
params = {
    'address': 'http://127.0.0.1:7860',  # Auto1111 WebUI base URL
    'mode': 0,  # index into ["Manual", "Immersive/Interactive", "Picturebook/Adventure"]
    'manage_VRAM': False,  # swap the LLM and the SD checkpoint in/out of VRAM
    'save_img': False,  # keep full-size PNGs on disk instead of inline thumbnails
    'SD_model': 'NeverEndingDream',
    'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
    'negative_prompt': '(worst quality, low quality:1.3)',
    'width': 512,
    'height': 512,
    'denoising_strength': 0.61,
    'restore_faces': False,
    'enable_hr': False,
    'hr_upscaler': 'ESRGAN_4x',
    # A float: it backs a float Slider (1..4, step 0.1) and is sent as the
    # numeric txt2img `hr_scale` field (previously the string '1.0').
    'hr_scale': 1.0,
    'seed': -1,
    'sampler_name': 'DPM++ 2M Karras',
    'steps': 32,
    'cfg_scale': 7,
    'textgen_prefix': 'Please provide a detailed and vivid description of [subject]',
    'sd_checkpoint': ' ',
    'checkpoint_list': [" "]
}
|
|
|
|
|
def _post_sd_endpoint(endpoint):
    """POST an empty JSON body to ``<address>/sdapi/v1/<endpoint>`` and raise on HTTP errors."""
    response = requests.post(url=f'{params["address"]}/sdapi/v1/{endpoint}', json='')
    response.raise_for_status()


def give_VRAM_priority(actor):
    """Hand VRAM to either the LLM or Auto1111's Stable Diffusion checkpoint.

    Args:
        actor: One of
            'SD'    -- unload the LLM, then ask Auto1111 to reload its last checkpoint;
            'LLM'   -- ask Auto1111 to unload its checkpoint, then reload the LLM;
            'set'   -- VRAM management switched on: ask Auto1111 to unload;
            'reset' -- VRAM management switched off: ask Auto1111 to reload.

    Raises:
        RuntimeError: if ``actor`` is not one of the states above.
        requests.exceptions.HTTPError: if Auto1111 answers with an error status.
    """
    if actor == 'SD':
        unload_model()
        print("Requesting Auto1111 to re-load last checkpoint used...")
        _post_sd_endpoint('reload-checkpoint')

    elif actor == 'LLM':
        print("Requesting Auto1111 to vacate VRAM...")
        _post_sd_endpoint('unload-checkpoint')
        reload_model()

    elif actor == 'set':
        print("VRAM management activated -- requesting Auto1111 to vacate VRAM...")
        _post_sd_endpoint('unload-checkpoint')

    elif actor == 'reset':
        print("VRAM management deactivated -- requesting Auto1111 to reload checkpoint")
        _post_sd_endpoint('reload-checkpoint')

    else:
        raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')
|
|
|
|
|
# With VRAM management already enabled at import time, have Auto1111 release
# its checkpoint immediately.
if params['manage_VRAM']:
    give_VRAM_priority('set')

# Known SD model names (not referenced elsewhere in the visible code).
SD_models = ["NeverEndingDream"]

# True while the next bot reply should be turned into a picture; flipped by
# toggle_generation().
picture_response = False
|
|
|
|
|
def remove_surrounded_chars(string):
    """Remove every ``*...*`` span (e.g. roleplay actions) from ``string``.

    An unterminated ``*`` removes everything up to the end of the string.
    The pattern is a raw string; the previous ``'\\*...'`` literal relied on
    the invalid escape sequence ``\\*``, which Python warns about.
    """
    return re.sub(r'\*[^*]*?(\*|$)', '', string)
|
|
|
|
|
def triggers_are_in(string):
    """Tell whether a user message asks for a picture.

    ``*...*`` spans are dropped first. The message then has to contain a
    verb-like word (send/mail/message/me) followed later by a picture noun
    (image, pic, picture, photo, snap, snapshot, selfie, meme, optionally plural).
    """
    cleaned = remove_surrounded_chars(string)
    found = re.search(
        r'(?aims)(send|mail|message|me)\b.+?\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\b',
        cleaned,
    )
    return found is not None
|
|
|
|
|
def state_modifier(state):
    """Turn off streaming for the generation state while picture mode is active.

    NOTE(review): presumably so output_modifier receives the complete reply
    to use as the image prompt -- confirm.
    """
    if not picture_response:
        return state
    state['stream'] = False
    return state
|
|
|
|
|
def input_modifier(string):
    """
    This function is applied to your text inputs before
    they are fed into the model.

    Only active in mode 1 (Immersive/Interactive). When the message asks for a
    picture (see triggers_are_in), picture mode is switched on. The message is
    then replaced by ``params['textgen_prefix']``, with ``[subject]`` filled in
    from whatever follows the first standalone word "of" (lower-cased). If
    there is no such subject, a default self-description request is used.
    """
    if params['mode'] != 1:
        return string

    if triggers_are_in(string):
        toggle_generation(True)
        lowered = string.lower()
        # Split on the *word* "of": a plain substring split would also cut
        # inside words such as "profile" or "often".
        parts = re.split(r'\bof\b', lowered, maxsplit=1)
        subject = parts[1].strip() if len(parts) > 1 else ''
        if not subject:
            subject = "your appearance, your surroundings and what you are doing right now"
        string = params['textgen_prefix'].replace("[subject]", subject)

    return string
|
|
|
|
|
def _save_image_html(img_str, character, alt, index=0):
    """Write one base64 image under extensions/sd_api_pictures/outputs/ and return an <img> tag for it."""
    img_data = base64.b64decode(img_str.split(",", 1)[-1])
    # Extra images from one response get an index suffix so they do not
    # overwrite the first one (the timestamp only has one-second resolution).
    suffix = f'_{index}' if index else ''
    variadic = f'{date.today().strftime("%Y_%m_%d")}/{character}_{int(time.time())}{suffix}'
    output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_file.write_bytes(img_data)
    return f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{alt}" style="max-width: unset; max-height: unset;">\n'


def _inline_image_html(img_str, alt):
    """Return an <img> tag embedding a JPEG thumbnail (max 300x300) of a base64 image."""
    # Take the part after the comma, so both raw base64 and "data:...;base64,XXX" work.
    image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[-1])))
    image.thumbnail((300, 300))
    if image.mode != 'RGB':
        # JPEG cannot store alpha/palette modes (e.g. RGBA); Pillow raises otherwise.
        image = image.convert('RGB')
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    data_uri = "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode()
    return f'<img src="{data_uri}" alt="{alt}">\n'


def get_SD_pictures(description, character):
    """Generate images for ``description`` via Auto1111's txt2img API.

    Args:
        description: Prompt text, appended to ``params['prompt_prefix']``.
            ``<audio>...</audio>`` fragments are stripped first.
        character: Character name, used in the file name of saved images.

    Returns:
        HTML with one ``<img>`` tag per returned image. When
        ``params['save_img']`` is set, the tag links to a saved PNG;
        otherwise it embeds a JPEG thumbnail as a data URI.

    Raises:
        requests.exceptions.HTTPError: if the API answers with an error status.
    """
    import html  # only needed here, to escape model text placed in the alt attribute

    manage_vram = params['manage_VRAM']
    description = re.sub('<audio.*?</audio>', ' ', description)
    description = f"({description}:1)"

    try:
        if manage_vram:
            give_VRAM_priority('SD')

        payload = {
            "prompt": params['prompt_prefix'] + description,
            "seed": params['seed'],
            "sampler_name": params['sampler_name'],
            "enable_hr": params['enable_hr'],
            "hr_scale": params['hr_scale'],
            "hr_upscaler": params['hr_upscaler'],
            "denoising_strength": params['denoising_strength'],
            "steps": params['steps'],
            "cfg_scale": params['cfg_scale'],
            "width": params['width'],
            "height": params['height'],
            "restore_faces": params['restore_faces'],
            "override_settings_restore_afterwards": True,
            "negative_prompt": params['negative_prompt']
        }

        print(f'Prompting the image generator via the API on {params["address"]}...')
        response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
        response.raise_for_status()
        r = response.json()

        # The description is model output; escape it before putting it in HTML.
        alt = html.escape(description, quote=True)
        save_img = params['save_img']
        tags = []
        for index, img_str in enumerate(r['images']):
            if save_img:
                tags.append(_save_image_html(img_str, character, alt, index))
            else:
                tags.append(_inline_image_html(img_str, alt))
        return ''.join(tags)
    finally:
        # Give the GPU back to the LLM even when generation failed; otherwise
        # the LLM would stay unloaded.
        if manage_vram:
            give_VRAM_priority('LLM')
|
|
|
|
|
|
|
def output_modifier(string, state):
    """
    This function is applied to the model outputs.

    Only active while picture mode (``picture_response``) is on. The reply is
    cleaned (``*...*`` spans removed, double quotes dropped, newlines
    flattened) and used as the image prompt. In modes 0/1, picture mode is
    switched off and the reply is shown as a quoted caption. In mode 2, the
    cleaned text is shown as-is under the image.
    """
    if not picture_response:
        return string

    string = remove_surrounded_chars(string)
    # Drop straight and *both* curly double quotes (the closing one used to be
    # kept), and flatten newlines into spaces.
    string = string.translate(str.maketrans({'"': '', '“': '', '”': '', '\n': ' '}))
    string = string.strip()

    if string == '':
        return 'no viable description in reply, try regenerating'

    if params['mode'] < 2:
        toggle_generation(False)
        text = f'*Sends a picture which portrays: “{string}”*'
    else:
        text = string

    return get_SD_pictures(string, state['character_menu']) + "\n" + text
|
|
|
|
|
def bot_prefix_modifier(string):
    """Chat-mode hook for the bot's prefix text.

    It could be used to bias the bot's behaviour; this extension leaves the
    prefix untouched and returns it unchanged.
    """
    unchanged = string
    return unchanged
|
|
|
|
|
def toggle_generation(*args):
    """Set or flip picture mode and update the UI's processing message.

    With no arguments, ``picture_response`` is inverted. Otherwise it is set
    to the first argument. ``shared.processing_message`` is updated to match.
    """
    global picture_response

    picture_response = (not picture_response) if len(args) == 0 else args[0]

    if picture_response:
        shared.processing_message = "*Is sending a picture...*"
    else:
        shared.processing_message = "*Is typing...*"
|
|
|
|
|
def filter_address(address):
    """Normalize a user-typed Auto1111 address.

    Surrounding whitespace and trailing slashes are removed. ``http://`` is
    prepended unless the address already starts with an ``http://`` or
    ``https://`` scheme (case-insensitive). A bare ``startswith('http')``
    check would wrongly skip hosts such as ``httpbin.local``.
    """
    address = address.strip().rstrip('/')
    if not re.match(r'https?://', address, re.IGNORECASE):
        address = 'http://' + address
    return address
|
|
|
|
|
def SD_api_address_update(address):
    """Store a new Auto1111 address and probe it.

    The address is normalized with filter_address and saved in ``params``.
    A GET to ``/sdapi/v1/sd-models`` then checks it.

    Returns:
        A Textbox update whose label reports whether the API answered.
    """
    address = filter_address(address)
    params.update({"address": address})

    try:
        response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
        response.raise_for_status()
    except requests.exceptions.RequestException:
        # Connection problems, bad URLs and HTTP error statuses all mean "not found";
        # unlike the old bare except, KeyboardInterrupt/SystemExit still propagate.
        msg = "❌ No SD API endpoint on:"
    else:
        msg = "✔️ SD API is found on:"

    return gr.Textbox.update(label=msg)
|
|
|
|
|
def custom_css():
    """Return the contents of the ``style.css`` file located next to this module.

    Uses Path.read_text so the file handle is closed (the old bare open() leaked it).
    """
    path_to_css = Path(__file__).parent.resolve() / 'style.css'
    return path_to_css.read_text(encoding='utf-8')
|
|
|
|
|
def get_checkpoints():
    """Refresh the checkpoint list and the active checkpoint from Auto1111.

    Queries ``/sdapi/v1/sd-models`` and ``/sdapi/v1/options`` and stores the
    results in ``params['checkpoint_list']`` and ``params['sd_checkpoint']``.
    On any request or response-format failure, both are reset to empty.

    Returns:
        A Gradio update for the checkpoint dropdown.
    """
    try:
        models = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
        models.raise_for_status()
        options = requests.get(url=f'{params["address"]}/sdapi/v1/options')
        options.raise_for_status()
        current = options.json()['sd_model_checkpoint']
        titles = [result["title"] for result in models.json()]
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as err:
        # ValueError covers invalid JSON; KeyError/TypeError an unexpected payload shape.
        print(f'Could not fetch the SD checkpoint list: {err}')
        current, titles = "", []

    params['sd_checkpoint'] = current
    params['checkpoint_list'] = titles
    return gr.update(choices=params['checkpoint_list'], value=params['sd_checkpoint'])
|
|
|
|
|
def load_checkpoint(checkpoint):
    """Ask Auto1111 to switch to ``checkpoint`` (best effort).

    Empty or blank values are ignored. They occur, for example, after a failed
    checkpoint-list refresh resets the dropdown. Request failures are logged
    instead of being silently discarded.
    """
    if not checkpoint or not checkpoint.strip():
        return

    payload = {
        "sd_model_checkpoint": checkpoint
    }

    try:
        response = requests.post(url=f'{params["address"]}/sdapi/v1/options', json=payload)
        response.raise_for_status()
    except requests.exceptions.RequestException as err:
        print(f'Could not load SD checkpoint "{checkpoint}": {err}')
|
|
|
|
|
def get_samplers():
    """Return the sampler names reported by ``/sdapi/v1/samplers``, or ``[]`` on failure."""
    try:
        response = requests.get(url=f'{params["address"]}/sdapi/v1/samplers')
        response.raise_for_status()
        return [sampler["name"] for sampler in response.json()]
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError):
        # Network/HTTP errors, invalid JSON, or an unexpected payload shape.
        return []
|
|
|
|
|
def ui():
    """Build the extension's Gradio controls and wire them to ``params``.

    Each control writes its value back into the module-level ``params`` dict
    on change. Some also trigger side effects: VRAM hand-off, the picture-mode
    toggle, and checkpoint list refresh/load.
    """
    with gr.Accordion("Parameters", open=True, elem_classes="SDAP"):
        with gr.Row():
            address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address')
            modes_list = ["Manual", "Immersive/Interactive", "Picturebook/Adventure"]
            mode = gr.Dropdown(modes_list, value=modes_list[params['mode']], label="Mode of operation", type="index")
            with gr.Column(scale=1, min_width=300):
                manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM')
                save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat')

            force_pic = gr.Button("Force the picture response")
            suppr_pic = gr.Button("Suppress the picture response")
        with gr.Row():
            checkpoint = gr.Dropdown(params['checkpoint_list'], value=params['sd_checkpoint'], label="Checkpoint", type="value")
            update_checkpoints = gr.Button("Get list of checkpoints")

        with gr.Accordion("Generation parameters", open=False):
            prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
            textgen_prefix = gr.Textbox(placeholder=params['textgen_prefix'], value=params['textgen_prefix'], label='textgen prefix (type [subject] where the subject should be placed)')
            negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
            with gr.Row():
                with gr.Column():
                    width = gr.Slider(64, 2048, value=params['width'], step=64, label='Width')
                    height = gr.Slider(64, 2048, value=params['height'], step=64, label='Height')
                with gr.Column(variant="compact", elem_id="sampler_col"):
                    with gr.Row(elem_id="sampler_row"):
                        sampler_name = gr.Dropdown(value=params['sampler_name'], allow_custom_value=True, label='Sampling method', elem_id="sampler_box")
                        create_refresh_button(sampler_name, lambda: None, lambda: {'choices': get_samplers()}, 'refresh-button')
                    steps = gr.Slider(1, 150, value=params['steps'], step=1, label="Sampling steps", elem_id="steps_box")
            with gr.Row():
                seed = gr.Number(label="Seed", value=params['seed'], elem_id="seed_box")
                cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
                # This column used to be bound to `hr_options` too; the binding was
                # immediately shadowed by the Row below, so it is dropped.
                with gr.Column():
                    restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
                    enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
            # Hires.-fix options, shown only while "Hires. fix" is ticked.
            with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
                hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
                denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
                hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')

    # Event functions to update the parameters in the backend.
    address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
    mode.select(lambda x: params.update({"mode": x}), mode, None)
    mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
    manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
    manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
    save_img.change(lambda x: params.update({"save_img": x}), save_img, None)

    address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
    prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
    textgen_prefix.change(lambda x: params.update({"textgen_prefix": x}), textgen_prefix, None)
    negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
    width.change(lambda x: params.update({"width": x}), width, None)
    height.change(lambda x: params.update({"height": x}), height, None)
    hr_scale.change(lambda x: params.update({"hr_scale": x}), hr_scale, None)
    denoising_strength.change(lambda x: params.update({"denoising_strength": x}), denoising_strength, None)
    restore_faces.change(lambda x: params.update({"restore_faces": x}), restore_faces, None)
    hr_upscaler.change(lambda x: params.update({"hr_upscaler": x}), hr_upscaler, None)
    enable_hr.change(lambda x: params.update({"enable_hr": x}), enable_hr, None)
    # Use the event value directly: reading params["enable_hr"] depended on the
    # handler above having already run.
    enable_hr.change(lambda x: gr.update(visible=x), enable_hr, hr_options)
    update_checkpoints.click(get_checkpoints, None, checkpoint)
    checkpoint.change(lambda x: params.update({"sd_checkpoint": x}), checkpoint, None)
    checkpoint.change(load_checkpoint, checkpoint, None)

    sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
    steps.change(lambda x: params.update({"steps": x}), steps, None)
    seed.change(lambda x: params.update({"seed": x}), seed, None)
    cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)

    force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
    suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)
|
|