|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import atexit |
|
import json |
|
import random |
|
import tempfile |
|
import traceback |
|
import logging |
|
import threading |
|
from concurrent.futures import ThreadPoolExecutor |
|
from dataclasses import dataclass, replace |
|
from datetime import datetime |
|
from typing import Any, Dict, List |
|
from deforum_api_models import Batch, DeforumJobErrorType, DeforumJobStatusCategory, DeforumJobPhase, DeforumJobStatus |
|
from contextlib import contextmanager |
|
from deforum_extend_paths import deforum_sys_extend |
|
|
|
import gradio as gr |
|
from deforum_helpers.args import (DeforumAnimArgs, DeforumArgs, |
|
DeforumOutputArgs, LoopArgs, ParseqArgs, |
|
RootArgs, get_component_names) |
|
from fastapi import FastAPI, Response, status |
|
|
|
from modules.shared import cmd_opts, opts, state |
|
|
|
|
|
# Module-level logger. The level is taken from DEFORUM_API_LOG_LEVEL, then
# SD_WEBUI_LOG_LEVEL, and defaults to INFO when neither is set.
log = logging.getLogger(__name__)
log_level = (os.environ.get("DEFORUM_API_LOG_LEVEL") or os.environ.get("SD_WEBUI_LOG_LEVEL") or "INFO").strip().upper()
# Logger.setLevel() raises ValueError for level names it does not know
# (including lower-case spellings such as "debug"), which would abort the
# import of this module. The name is upper-cased above; anything that still
# does not resolve to a numeric level falls back to INFO.
if not isinstance(logging.getLevelName(log_level), int):
    log_level = "INFO"
log.setLevel(log_level)
logging.basicConfig(
    format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
|
|
|
def make_ids(job_count: int):
    """Generate a random batch id and the ids of the jobs it contains.

    Args:
        job_count: Number of jobs in the batch.

    Returns:
        A two-element list ``[batch_id, job_ids]`` where ``batch_id`` has the
        form ``batch(<random int>)`` and ``job_ids`` is the list
        ``["<batch_id>-0", "<batch_id>-1", ...]`` of length ``job_count``.
    """
    # The upper bound must be an int: random.randint() rejects float bounds
    # such as 1e9 on Python 3.12+ (and warned about them before that).
    batch_id = f"batch({random.randint(0, 10**9)})"
    job_ids = [f"{batch_id}-{i}" for i in range(job_count)]
    return [batch_id, job_ids]
|
|
|
|
|
def get_default_value(name:str):
    """Return the default value registered for the argument called *name*.

    The argument tables returned by RootArgs, DeforumAnimArgs, DeforumArgs,
    LoopArgs, ParseqArgs and DeforumOutputArgs are merged (later tables win on
    duplicate keys). If the entry for *name* is a dict, its ``"value"`` item is
    returned (``None`` if it has none); any other entry is returned as-is.
    ``None`` is returned when *name* is unknown.
    """
    merged_defaults = RootArgs() | DeforumAnimArgs() | DeforumArgs() | LoopArgs() | ParseqArgs() | DeforumOutputArgs()
    if name not in merged_defaults:
        return None
    entry = merged_defaults[name]
    if isinstance(entry, dict):
        return entry.get("value", None)
    return entry
|
|
|
|
|
def run_deforum_batch(batch_id: str, job_ids: List[str], deforum_settings_files: List[Any], opts_overrides: Dict[str, Any] = None):
    """Run one Deforum batch synchronously in the calling thread.

    Builds the positional argument list expected by ``run_deforum`` from the
    component defaults, points it at the given settings files and invokes it
    while the requested web UI option overrides are active.

    Args:
        batch_id: Identifier of the batch; passed to ``run_deforum`` as its
            first positional argument.
        job_ids: Ids of the jobs in this batch; every one of them is marked
            as failed if the run raises.
        deforum_settings_files: Settings files handed to ``run_deforum`` as
            the ``custom_settings_file`` component value.
        opts_overrides: Optional web UI options to apply for the duration of
            the run (see ``A1111OptionsOverrider``).

    Exceptions raised by the run are logged and recorded on the jobs; they are
    not propagated to the caller.
    """
    log.info(f"Starting batch {batch_id} in thread {threading.get_ident()}.")
    try:
        with A1111OptionsOverrider(opts_overrides):

            # Start from the default of every component. The first
            # `prefixed_gradio_args` slots precede the component values.
            component_names = get_component_names()
            prefixed_gradio_args = 2
            expected_arg_count = prefixed_gradio_args + len(component_names)
            run_deforum_args = [None] * expected_arg_count
            for idx, name in enumerate(component_names):
                run_deforum_args[prefixed_gradio_args + idx] = get_default_value(name)

            # Placeholder prompt values.
            # NOTE(review): presumably these are superseded by the settings
            # file enabled below - confirm against run_deforum.
            run_deforum_args[prefixed_gradio_args + component_names.index('animation_prompts')] = '{"0":"dummy value"}'
            run_deforum_args[prefixed_gradio_args + component_names.index('animation_prompts_negative')] = ''

            # Slot 0 carries the batch id; slot 1 is left as None.
            run_deforum_args[0] = batch_id

            # Make the run take its settings from the supplied files.
            run_deforum_args[prefixed_gradio_args + component_names.index('override_settings_with_file')] = True
            run_deforum_args[prefixed_gradio_args + component_names.index('custom_settings_file')] = deforum_settings_files

            # Clear the shared skip/interrupt flags before starting.
            state.skipped = False
            state.interrupted = False

            # Imported here rather than at module import time.
            from deforum_helpers.run_deforum import run_deforum
            run_deforum(*run_deforum_args)

    except Exception as e:
        log.error(f"Batch {batch_id} failed: {e}")
        traceback.print_exc()
        for job_id in job_ids:
            # Mark every job of the batch as failed. The message must be a
            # string: the previous `{e}` built a set containing the exception
            # object instead of formatting it.
            JobStatusTracker().fail_job(job_id, 'TERMINAL', str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def deforum_api(_: gr.Blocks, app: FastAPI):
    """Register the ``/deforum_api`` batch and job endpoints on *app*.

    Args:
        _: Unused first callback argument.
        app: FastAPI application the routes are attached to.
    """

    deforum_sys_extend()

    apiState = ApiState()

    # Submit a new batch.
    @app.post("/deforum_api/batches")
    async def run_batch(batch: Batch, response: Response):

        # Extract the settings payload(s) from the request.
        deforum_settings_data = batch.deforum_settings
        if not deforum_settings_data:
            response.status_code = status.HTTP_400_BAD_REQUEST
            return {"message": "No settings files provided. Please provide an element 'deforum_settings' of type list in the request JSON payload."}

        if not isinstance(deforum_settings_data, list):
            # A single settings object is treated as a batch of one.
            deforum_settings_data = [deforum_settings_data]

        # Write each settings object to its own temporary file. The closed
        # file objects (which still expose `.name`) are what gets passed on.
        # NOTE(review): the files are created with delete=False and nothing
        # visible here removes them afterwards.
        deforum_settings_tempfiles = []
        for data in deforum_settings_data:
            # `with` closes the file even when json.dump raises.
            with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as temp_file:
                json.dump(data, temp_file)
            deforum_settings_tempfiles.append(temp_file)

        job_count = len(deforum_settings_tempfiles)
        [batch_id, job_ids] = make_ids(job_count)

        # Register the jobs BEFORE handing the batch to the executor. If the
        # worker started (and possibly failed) first, its status updates would
        # be dropped because the tracker ignores unknown job ids, and the job
        # would then be recorded as ACCEPTED forever.
        for idx, job_id in enumerate(job_ids):
            JobStatusTracker().accept_job(batch_id=batch_id, job_id=job_id, deforum_settings=deforum_settings_data[idx], options_overrides=batch.options_overrides)

        try:
            apiState.submit_job(batch_id, job_ids, deforum_settings_tempfiles, batch.options_overrides)
        except Exception as e:
            # The batch never reached the queue: do not leave its jobs
            # looking as if they were still waiting to run.
            for job_id in job_ids:
                JobStatusTracker().fail_job(job_id, 'TERMINAL', str(e))
            raise

        response.status_code = status.HTTP_202_ACCEPTED
        return {"message": "Job(s) accepted", "batch_id": batch_id, "job_ids": job_ids }

    # List all batches and their job ids.
    @app.get("/deforum_api/batches")
    async def list_batches(id: str = None):
        # `id` is not used; it is kept optional so existing callers that send
        # it keep working while a plain GET no longer fails validation.
        return JobStatusTracker().batches

    # Show the status of every job in a batch.
    @app.get("/deforum_api/batches/{id}")
    async def get_batch(id: str, response: Response):
        # .get() so that an unknown id reaches the 404 branch instead of
        # raising KeyError.
        jobsForBatch = JobStatusTracker().batches.get(id)
        if not jobsForBatch:
            response.status_code = status.HTTP_404_NOT_FOUND
            return {"id": id, "status": "NOT FOUND"}
        return [JobStatusTracker().get(job_id) for job_id in jobsForBatch]

    # Cancel every job in a batch.
    @app.delete("/deforum_api/batches/{id}")
    async def cancel_batch(id: str, response: Response):
        jobsForBatch = JobStatusTracker().batches.get(id)
        cancelled_jobs = []
        if not jobsForBatch:
            response.status_code = status.HTTP_404_NOT_FOUND
            return {"id": id, "status": "NOT FOUND"}
        for job_id in jobsForBatch:
            try:
                cancelled = _cancel_job(job_id)
                if cancelled:
                    cancelled_jobs.append(job_id)
            except Exception as e:
                # Best effort: keep going with the remaining jobs.
                log.warning(f"Failed to cancel job {job_id}: {e}")

        # NOTE(review): the key "message:" (with a trailing colon) looks like
        # a typo, but it is part of the response format and is kept as-is.
        return {"ids": cancelled_jobs, "message:": f"{len(cancelled_jobs)} job(s) cancelled." }

    # Show the status of all jobs.
    @app.get("/deforum_api/jobs")
    async def list_jobs():
        return JobStatusTracker().statuses

    # Show the status of a single job.
    @app.get("/deforum_api/jobs/{id}")
    async def get_job(id: str, response: Response):
        jobStatus = JobStatusTracker().get(id)
        if not jobStatus:
            response.status_code = status.HTTP_404_NOT_FOUND
            return {"id": id, "status": "NOT FOUND"}
        return jobStatus

    # Cancel a single job.
    @app.delete("/deforum_api/jobs/{id}")
    async def cancel_job(id: str, response: Response):
        try:
            if _cancel_job(id):
                return {"id": id, "message": "Job cancelled."}
            else:
                response.status_code = status.HTTP_400_BAD_REQUEST
                return {"id": id, "message": f"Job with ID {id} not in a cancellable state. Has it already finished?"}
        except FileNotFoundError:
            # _cancel_job signals an unknown job id with FileNotFoundError.
            response.status_code = status.HTTP_404_NOT_FOUND
            return {"id": id, "message": f"Job with ID {id} not found."}
|
|
|
|
|
def _cancel_job(job_id:str):
    """Cancel the job with the given id if it is still in the ACCEPTED state.

    Args:
        job_id: Id of the job to cancel.

    Returns:
        ``True`` when the job was marked as cancelled, ``False`` when its
        status is anything other than ACCEPTED.

    Raises:
        FileNotFoundError: If the tracker has no job with this id.
    """
    tracker = JobStatusTracker()
    job_status = tracker.get(job_id)
    if not job_status:
        raise FileNotFoundError(f"Job {job_id} not found.")

    # Only jobs that are still ACCEPTED can be cancelled.
    if job_status.status != DeforumJobStatusCategory.ACCEPTED:
        return False

    # Cancel the pending future, if one is registered under this key.
    # NOTE(review): ApiState.submit_job stores futures under the *batch* id,
    # so a lookup by job id looks like it can never match - confirm.
    pending = ApiState().submitted_jobs
    if job_id in pending:
        pending[job_id].cancel()

    # Any phase other than QUEUED/DONE is treated as "currently running" and
    # is stopped through the shared interrupt flag.
    # NOTE(review): `state` is a shared object, so this interrupts whatever is
    # running at this moment - verify it cannot hit a different job.
    if job_status.phase not in (DeforumJobPhase.QUEUED, DeforumJobPhase.DONE):
        state.interrupt()

    tracker.cancel_job(job_id, "Cancelled due to user request.")
    return True
|
|
|
class Singleton(type):
    """Metaclass that creates at most one instance of each class using it.

    The first call to the class constructs the instance; every later call
    returns that same object and ignores the arguments it was given.
    """

    # Maps each class to its single instance; shared by all classes that use
    # this metaclass.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
|
|
|
|
|
class ApiState(metaclass=Singleton):
    """Holds the executor that runs batches and the futures it has produced."""

    # One worker only, so submitted batches are executed one after another.
    deforum_api_executor = ThreadPoolExecutor(max_workers=1)
    # Future of every submitted batch, keyed by batch id.
    submitted_jobs : Dict[str, Any] = {}

    @staticmethod
    def cleanup():
        """Shut the executor down without waiting for pending work."""
        executor = ApiState().deforum_api_executor
        executor.shutdown(wait=False)

    def submit_job(self, batch_id: str, job_ids: [str], deforum_settings: List[Any], opts_overrides: Dict[str, Any]):
        """Queue a batch on the executor and remember its future by batch id."""
        log.debug(f"Submitting batch {batch_id} to threadpool.")
        self.submitted_jobs[batch_id] = self.deforum_api_executor.submit(
            run_deforum_batch, batch_id, job_ids, deforum_settings, opts_overrides
        )


# Make sure the executor is shut down when the interpreter exits.
atexit.register(ApiState.cleanup)
|
|
|
|
|
class A1111OptionsOverrider(object):
    """Context manager that temporarily overrides web UI options.

    On entry every key/value pair of ``opts_overrides`` is set on the shared
    ``opts`` object; on exit the values that were present in ``opts.data``
    beforehand are written back. Exceptions raised inside the ``with`` block
    are logged and then propagate unchanged.
    """

    def __init__(self, opts_overrides: Dict[str, Any]):
        """Remember the overrides to apply; ``None`` or empty means none."""
        self.opts_overrides = opts_overrides
        # Set in __enter__; initialised here so __exit__ is always safe.
        self.original_opts = None

    def __enter__(self):
        # Truthiness covers both None and an empty mapping. The previous
        # `len(...) > 1` check silently ignored a single override.
        if self.opts_overrides:
            # Only keys that already exist in opts.data can be restored later.
            self.original_opts = {k: opts.data[k] for k in self.opts_overrides.keys() if k in opts.data}
            log.debug(f"Captured options to override: {self.original_opts}")
            log.info(f"Setting options: {self.opts_overrides}")
            for k, v in self.opts_overrides.items():
                setattr(opts, k, v)
        else:
            self.original_opts = None
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        if (exception_type is not None):
            log.warning(f"Error during batch execution: {exception_type} - {exception_value}")
            log.debug(f"{traceback}")
        if (self.original_opts is not None):
            log.info(f"Restoring options: {self.original_opts}")
            for k, v in self.original_opts.items():
                setattr(opts, k, v)
|
|
|
|
|
|
|
|
|
class JobStatusTracker(metaclass=Singleton):
    """In-memory registry of job statuses and of the jobs in each batch.

    Status records are never modified in place: every update stores a new
    record produced with ``dataclasses.replace``. Updates for job ids that
    were never accepted are silently ignored.
    """

    # job id -> latest status record
    statuses: Dict[str, DeforumJobStatus] = {}
    # batch id -> ids of the jobs accepted for that batch, in order
    batches: Dict[str, List[str]] = {}

    def accept_job(self, batch_id : str, job_id: str, deforum_settings : List[Dict[str, Any]] , options_overrides : Dict[str, Any]):
        """Register a new job under its batch with status ACCEPTED / QUEUED."""
        self.batches.setdefault(batch_id, []).append(job_id)

        now = datetime.now().timestamp()
        self.statuses[job_id] = DeforumJobStatus(
            id=job_id,
            status= DeforumJobStatusCategory.ACCEPTED,
            phase=DeforumJobPhase.QUEUED,
            error_type=DeforumJobErrorType.NONE,
            phase_progress=0.0,
            started_at=now,
            last_updated=now,
            execution_time=0,
            update_interval_time=0,
            updates=0,
            message=None,
            outdir=None,
            timestring=None,
            deforum_settings=deforum_settings,
            options_overrides=options_overrides,
        )

    def _apply_update(self, job_id: str, **changes: Any) -> None:
        """Store a copy of the job's status with *changes* applied.

        Besides the given fields this refreshes ``last_updated``, recomputes
        ``execution_time`` (time since ``started_at``) and
        ``update_interval_time`` (time since the previous update) and
        increments ``updates``. Unknown job ids are ignored.
        """
        if job_id not in self.statuses:
            return
        current_status = self.statuses[job_id]
        now = datetime.now().timestamp()
        self.statuses[job_id] = replace(
            current_status,
            last_updated=now,
            execution_time=now-current_status.started_at,
            update_interval_time=now-current_status.last_updated,
            updates=current_status.updates+1,
            **changes,
        )

    def update_phase(self, job_id: str, phase: DeforumJobPhase, progress: float = 0):
        """Record the phase the job is in and the progress within that phase."""
        self._apply_update(job_id, phase=phase, phase_progress=progress)

    def update_output_info(self, job_id: str, outdir: str, timestring: str):
        """Record the job's output directory and timestring."""
        self._apply_update(job_id, outdir=outdir, timestring=timestring)

    def complete_job(self, job_id: str):
        """Mark the job as SUCCEEDED, in phase DONE with full progress."""
        self._apply_update(
            job_id,
            status=DeforumJobStatusCategory.SUCCEEDED,
            phase=DeforumJobPhase.DONE,
            phase_progress=1.0,
        )

    def fail_job(self, job_id: str, error_type: str, message: str):
        """Mark the job as FAILED with the given error type and message."""
        self._apply_update(
            job_id,
            status=DeforumJobStatusCategory.FAILED,
            error_type=error_type,
            message=message,
        )

    def cancel_job(self, job_id: str, message: str):
        """Mark the job as CANCELLED with the given message."""
        self._apply_update(
            job_id,
            status=DeforumJobStatusCategory.CANCELLED,
            message=message,
        )

    def get(self, job_id:str):
        """Return the latest status of the job, or ``None`` if it is unknown."""
        return self.statuses.get(job_id)
|
|
|
def deforum_init_batch(_: gr.Blocks, app: FastAPI):
    """Run the batch given by the ``deforum_run_now`` command-line option.

    ``cmd_opts.deforum_run_now`` is read as a comma-separated list of settings
    file names. The batch runs synchronously; afterwards the process is
    terminated if ``cmd_opts.deforum_terminate_after_run_now`` is set.

    Args:
        _: Unused first callback argument.
        app: Unused FastAPI application.
    """
    deforum_sys_extend()

    settings_files = []
    try:
        # Opened one by one so that the files already opened are still closed
        # if a later name cannot be opened.
        for filename in cmd_opts.deforum_run_now.split(","):
            settings_files.append(open(filename, 'r'))

        [batch_id, job_ids] = make_ids(len(settings_files))
        log.info(f"Starting init batch {batch_id} with job(s) {job_ids}...")

        run_deforum_batch(batch_id, job_ids, settings_files, None)
    finally:
        # The handles were previously never closed.
        for settings_file in settings_files:
            settings_file.close()

    if cmd_opts.deforum_terminate_after_run_now:
        # Exit immediately, without running interpreter cleanup handlers.
        os._exit(0)
|
|
|
|
|
def deforum_simple_api(_: gr.Blocks, app: FastAPI):
    """Register the simple ``/deforum`` endpoints on *app*.

    Adds a JSON handler for request validation errors, two version endpoints
    and ``POST /deforum/run``, which runs a single job synchronously.

    Args:
        _: Unused first callback argument.
        app: FastAPI application the routes are attached to.
    """
    deforum_sys_extend()
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse
    from fastapi import Request
    from fastapi.encoders import jsonable_encoder
    from deforum_helpers.general_utils import get_deforum_version
    import uuid, pathlib

    class SettingsWrapper:
        """Minimal stand-in for a file object: exposes only ``name``."""
        def __init__(self, filename):
            self.name = filename

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError):
        return JSONResponse(
            status_code=422,
            content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
        )

    @app.get("/deforum/api_version")
    async def deforum_api_version():
        return JSONResponse(content={"version": '1.0'})

    @app.get("/deforum/version")
    async def deforum_version():
        return JSONResponse(content={"version": get_deforum_version()})

    @app.post("/deforum/run")
    async def deforum_run(settings_json:str, allowed_params:str = ""):
        """Run one job built from the default settings plus allowed overrides.

        ``settings_json`` is a JSON object of settings; only keys that exist
        in ``default_settings.txt`` and are listed in the ';'-separated
        ``allowed_params`` are taken from it. Responds with the output
        directory, or with status 500 if preparing the run raised.

        NOTE(review): run_deforum_batch logs and swallows errors of the run
        itself, so those do not reach the 500 branch - confirm that is
        intended.
        """
        settings_file = None
        try:
            permitted = set(allowed_params.split(';'))
            requested_settings = json.loads(settings_json)
            defaults_path = os.path.join(pathlib.Path(__file__).parent.absolute(), 'default_settings.txt')
            with open(defaults_path, 'r', encoding='utf-8') as f:
                default_settings = json.loads(f.read())
            # Only existing keys are reassigned, so the dict does not change
            # size while it is being iterated.
            for k in default_settings:
                if k in requested_settings and k in permitted:
                    default_settings[k] = requested_settings[k]
            run_id = uuid.uuid4().hex
            default_settings['batch_name'] = run_id
            serialized_settings = json.dumps(default_settings, indent=4, ensure_ascii=False)
            # The settings are handed over through a file in the current
            # working directory.
            settings_file = f"{run_id}.txt"
            with open(settings_file, 'w', encoding='utf-8') as f:
                f.write(serialized_settings)
            [batch_id, job_ids] = make_ids(1)
            outdir = os.path.join(os.getcwd(), opts.outdir_samples or opts.outdir_img2img_samples, str(run_id))
            run_deforum_batch(batch_id, job_ids, [SettingsWrapper(settings_file)], None)
            return JSONResponse(content={"outdir": outdir})
        except Exception as e:
            log.error(f"deforum_run failed: {e}")
            traceback.print_exc()
            return JSONResponse(status_code=500, content={"detail": "An error occurred while processing the video."},)
        finally:
            # Remove the per-request settings file; it used to be left behind
            # in the working directory after every call.
            if settings_file is not None:
                try:
                    os.remove(settings_file)
                except OSError:
                    pass
|
|
|
|
|
# Hook the API into the web UI according to the command-line flags.
try:
    import modules.script_callbacks as script_callbacks
    if cmd_opts.deforum_api:
        script_callbacks.on_app_started(deforum_api)
    if cmd_opts.deforum_simple_api:
        script_callbacks.on_app_started(deforum_simple_api)
    if cmd_opts.deforum_run_now:
        script_callbacks.on_app_started(deforum_init_batch)
except Exception as e:
    # Registration stays best-effort, but the reason is no longer discarded
    # and KeyboardInterrupt/SystemExit are no longer swallowed.
    log.debug(f"Deforum API callbacks not registered: {e}", exc_info=True)
|
|