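"""Gradio Space entry point for Engage Diffusion: clones the private engine repository,
loads the diffusion pipeline onto the GPU, and serves a watermarking `inference` UI."""
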
import torch
import numpy as np
import gradio as gr
import spaces
import cv2
import os
from typing import Dict
from PIL import Image
from huggingface_hub import Repository

# Clone the private engine repository before importing from it.
# The access token is read from the `model_fetch` environment variable.
token = os.environ['model_fetch']
engine_repo = Repository(local_dir="engine", clone_from="felixrosberg/EngageDiffusion", use_auth_token=token)

# These imports resolve only after the repository above has been cloned into ./engine.
from engine.ui_model import fetch_model, run_model
from engine.ui_gradio import fetch_ui
# Watermark assets: the Engage Studios logo resized to 700x88, its alpha channel as a
# [0, 1] mask, and its RGB pixels as a NumPy array.
engage_logo = Image.open("engage_studios_logo.png").resize((700, 88), Image.Resampling.BICUBIC)
engage_logo_mask = np.clip(np.array(engage_logo.split()[-1])[..., None] / 255, 0, 1)
engage_logo_np = np.array(engage_logo.convert('RGB'))

# Load the diffusion pipeline once at startup and keep it on the GPU.
pipe = fetch_model()
pipe.to('cuda')
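
# ZeroGPU decorator: a GPU is attached only for the duration of each call.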
@spaces.GPU
def inference(user_state, condition_image, settings, prompt, neg_prompt, inference_steps=8, num_images=2,
              guidance_scale=2.0,
              guidance_rescale=0.0,
              enable_freeu=False,
              height=1024,
              width=1024,
              condition_scale=0.5,
              sketch_detail=1.0,
              sketch_softness=0.5,
              inpaint_strength=0.9,
              exposure=0.0,
              enable_stylation=False,
              style_1_down=0.0,
              style_1_mid=0.0,
              style_1_up=0.0,
              style_2_down=0.0,
              style_2_mid=0.0,
              style_2_up=0.0,
              style_3_down=0.0,
              style_3_mid=0.0,
              style_3_up=0.0,
              style_4_down=0.0,
              style_4_mid=0.0,
              style_4_up=0.0,
              seed=None,
              progress=gr.Progress()):
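    """Generate images for the given prompt and UI settings, stamp the Engage logo
    onto each result, and append them to the gallery kept in `user_state`."""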
    images = run_model(pipe, user_state, condition_image, settings, prompt, neg_prompt, inference_steps, num_images,
                       guidance_scale,
                       guidance_rescale,
                       enable_freeu,
                       height,
                       width,
                       condition_scale,
                       sketch_detail,
                       sketch_softness,
                       inpaint_strength,
                       exposure,
                       enable_stylation,
                       style_1_down,
                       style_1_mid,
                       style_1_up,
                       style_2_down,
                       style_2_mid,
                       style_2_up,
                       style_3_down,
                       style_3_mid,
                       style_3_up,
                       style_4_down,
                       style_4_mid,
                       style_4_up,
                       seed,
                       progress)
    # Composite the logo into the bottom-left 88x700 corner of every image. Blending in
    # float avoids uint8 overflow; the logo RGB is assumed to be premultiplied by its alpha.
    for idx, im in enumerate(images):
        im = np.asarray(im).astype(np.float64)
        im[-88:, :700] = im[-88:, :700] * (1 - engage_logo_mask) + engage_logo_np
        images[idx] = Image.fromarray(np.clip(im, 0, 255).astype('uint8'))

    user_state["IMAGE_GALLERY"] += images
    return user_state["IMAGE_GALLERY"], user_state
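
# Build the Gradio UI (defined in the cloned engine repo) around `inference` and launch the app.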
engage_demo = fetch_ui(inference)
engage_demo.launch()