|
import io |
|
import sys |
|
|
|
import fastapi |
|
import gradio as gr |
|
from pydantic import BaseModel, Field |
|
|
|
import scripts.shared as shared |
|
from scripts.utilities import run_python |
|
|
|
# Handle of the script process started by the Run button, or None while idle.
# NOTE(review): presumably a subprocess.Popen returned by run_python -- confirm.
proc = None

# Completed lines of captured child output, oldest first. The API handler may
# rebind this name to a fresh list, so readers must look it up on every access.
outputs: list = []
|
|
|
|
|
def alive():
    """Tell whether a script process is currently being tracked.

    Returns:
        True while the module-level ``proc`` holds a process handle, False
        once it has been reset to ``None``.
    """
    if proc is None:
        return False
    return True
|
|
|
|
|
def initialize_runner(script_file, tmpls, opts):
    """Create the Run/Stop buttons for the current tab and return their wiring hook.

    The two ``gr.Button`` components are created immediately; the click
    handlers are attached only when the returned callable is invoked.

    Args:
        script_file: Script forwarded unchanged to ``run_python``.
        tmpls: Templates forwarded to ``run_python``, or a zero-argument
            callable returning them (resolved each time a run starts).
        opts: Options mapping, or a zero-argument callable returning one. The
            mapping is forwarded to ``run_python`` and its values are used as
            the inputs of the Run click event.

    Returns:
        A zero-argument function that binds ``run`` and ``stop`` to the buttons.
    """
    run_button = gr.Button(
        "Run",
        variant="primary",
        elem_id=f"kohya_sd_webui__{shared.current_tab}_run_button",
    )
    stop_button = gr.Button(
        "Stop",
        variant="secondary",
        elem_id=f"kohya_sd_webui__{shared.current_tab}_stop_button",
    )

    def get_templates():
        """Resolve ``tmpls``, calling it when it is a factory."""
        return tmpls() if callable(tmpls) else tmpls

    def get_options():
        """Resolve ``opts``, calling it when it is a factory."""
        return opts() if callable(opts) else opts

    def run(args):
        """Start the script and capture its output until it exits or is stopped.

        Does nothing when a process is already being tracked.
        """
        global proc
        if alive():
            return

        child = run_python(script_file, get_templates(), get_options(), args)
        proc = child
        pending = []  # characters of the line that is still being assembled

        def echo(text):
            # Console mirroring is best-effort: a console that cannot encode a
            # character must not cost us the captured log.
            try:
                if shared.cmd_opts.enable_console_log:
                    sys.stdout.write(text)
            except Exception:
                pass

        def collect(text):
            # ``outputs`` is looked up on every call because the API handler
            # may rebind it to a fresh list while we are running.
            for char in text:
                if char == "\n":
                    outputs.append("".join(pending))
                    pending.clear()
                else:
                    pending.append(char)

        try:
            # errors="replace" keeps the stream readable when the child emits
            # bytes that are not valid UTF-8 instead of losing a whole chunk.
            reader = io.TextIOWrapper(
                child.stdout, encoding="utf-8-sig", errors="replace"
            )
            # Compare against ``child`` (not merely "not None") so that a loop
            # left over from a stopped run never adopts a newer run's process.
            while proc is child and child.poll() is None:
                try:
                    char = reader.read(1)
                except Exception:
                    # Keep polling; the loop ends once the process exits.
                    continue
                echo(char)
                collect(char)

            if proc is child:
                # The process exited by itself: whatever it wrote just before
                # exiting (typically the final traceback) is still buffered.
                # NOTE(review): this blocks until the pipe reaches EOF, so it
                # waits if a surviving grandchild still holds the pipe open.
                try:
                    tail = reader.read()
                except Exception:
                    tail = ""
                echo(tail)
                collect(tail)

            if pending:
                # Final line that was not terminated by a newline.
                outputs.append("".join(pending))
        finally:
            # Always release the slot, even on an unexpected error, but never
            # clear a handle that belongs to a newer run.
            if proc is child:
                proc = None

    def stop():
        """Kill the tracked process, if any, and mark the runner as idle."""
        global proc
        child = proc
        if child is None:
            # Nothing is running; pressing Stop is a no-op rather than an error.
            return
        child.kill()
        proc = None
        print("killed the running process")

    def init():
        """Attach the click handlers to the Run and Stop buttons."""
        run_button.click(
            run,
            set(get_options().values()),
        )
        stop_button.click(stop)

    return init
|
|
|
|
|
class GetOutputRequest(BaseModel):
    """Parameters accepted by the terminal-output endpoint."""

    # Position in the captured log from which lines are returned.
    output_index: int = Field(
        title="Index of the beginning of the log to retrieve",
        default=0,
    )
    # When true, the captured log is emptied before lines are returned.
    clear_terminal: bool = Field(
        title="Whether to clear the terminal",
        default=False,
    )
|
|
|
|
|
class GetOutputResponse(BaseModel):
    """Payload returned by the terminal-output endpoint."""

    # Required field: the slice of captured output lines being returned.
    outputs: list = Field(..., title="List of terminal output")
|
|
|
|
|
class ProcessAliveResponse(BaseModel):
    """Payload returned by the process-status endpoint."""

    # Required field: result of the module's alive() check.
    alive: bool = Field(..., title="Whether the process is running.")
|
|
|
|
|
def api_get_outputs(req: GetOutputRequest):
    """Return the captured output lines from ``req.output_index`` onwards.

    When ``req.clear_terminal`` is set, the module-level log is first replaced
    by an empty list, so the response is then empty.

    Args:
        req: Start index and clear flag supplied by the client.

    Returns:
        A ``GetOutputResponse`` holding the selected lines.
    """
    global outputs
    if req.clear_terminal:
        outputs = []
    start = req.output_index
    # A slice that starts past the end is already [], so no length check.
    remaining = outputs[start:]
    return GetOutputResponse(outputs=remaining)
|
|
|
|
|
def api_get_isalive(req: fastapi.Request):
    """Report whether a script process is currently tracked as running.

    Args:
        req: Incoming request; accepted but not read.

    Returns:
        A ``ProcessAliveResponse`` carrying the result of ``alive()``.
    """
    running = alive()
    return ProcessAliveResponse(alive=running)
|
|
|
|
|
def initialize_api(app: fastapi.FastAPI):
    """Register the terminal-output and process-status routes on ``app``.

    Args:
        app: Application that receives the two routes.
    """
    # (path, handler, HTTP method, response model), registered in this order.
    routes = (
        (
            "/internal/extensions/kohya-sd-scripts-webui/terminal/outputs",
            api_get_outputs,
            "POST",
            GetOutputResponse,
        ),
        (
            "/internal/extensions/kohya-sd-scripts-webui/process/alive",
            api_get_isalive,
            "GET",
            ProcessAliveResponse,
        ),
    )
    for path, endpoint, method, model in routes:
        app.add_api_route(
            path,
            endpoint,
            methods=[method],
            response_model=model,
        )
|
|