# ViewDiffusion / app.py
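"""View Diffusion: a Gradio app that streams a webcam feed through the
runwayml/stable-diffusion-inpainting pipeline, inpainting the masked region
of each frame to extrapolate the field of view."""
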
import gradio
import torch
import numpy
from PIL import Image
from torchvision import transforms
from diffusers import StableDiffusionInpaintPipeline
from diffusers import DPMSolverMultistepScheduler

# Run on CUDA when available, otherwise fall back to CPU
deviceStr = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(deviceStr)

latents = None

def GenerateNewLatentsForInference():
    """Sample a fresh latent noise tensor; 64x64 latents decode to 512x512 pixels (VAE factor 8)."""
    global latents
    if deviceStr == "cuda":
        latents = torch.randn((1, 4, 64, 64), device=device, dtype=torch.float16)
    else:
        latents = torch.randn((1, 4, 64, 64), device=device)

# Load the inpainting pipeline; on CUDA use fp16 weights and xFormers attention to cut VRAM use.
# The safety_checker override is left commented out so the default checker stays active.
if deviceStr == "cuda":
    pipeline = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting",
                                                              revision="fp16",
                                                              torch_dtype=torch.float16)
                                                              #safety_checker=lambda images, **kwargs: (images, False))
    pipeline.to(device)
    pipeline.enable_xformers_memory_efficient_attention()
else:
    pipeline = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
                                                              #safety_checker=lambda images, **kwargs: (images, False))
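
# DPMSolverMultistepScheduler is imported above but never attached. A minimal
# sketch (an assumption, not part of the original app) of how it could be
# swapped in for the pipeline's default scheduler:
#
#   pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)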

# Prime the shared latents once before the first frame
GenerateNewLatentsForInference()

imageSize = (512, 512)
lastImage = Image.new(mode="RGB", size=imageSize)
lastSeed = 512
generator = torch.Generator(device).manual_seed(512)

def diffuse(staticLatents, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps, seed):
    global latents, lastSeed, generator, deviceStr, lastImage

    # Skip inference while paused or before a mask is supplied
    if mask is None or pauseInference is True:
        return lastImage

    # Resample the latent noise each frame unless static latents are requested
    if staticLatents is False:
        GenerateNewLatentsForInference()

    # Reseed the generator only when the seed slider actually changed
    if lastSeed != seed:
        generator = torch.Generator(device).manual_seed(seed)
        lastSeed = seed

    newImage = pipeline(prompt=prompt,
                        negative_prompt=negativePrompt,
                        image=inputImage,
                        mask_image=mask,
                        guidance_scale=guidanceScale,
                        num_inference_steps=numInferenceSteps,
                        latents=latents,
                        generator=generator).images[0]

    lastImage = newImage

    return newImage
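
# Hedged offline sketch (not part of the original app): diffuse() can be driven
# without the webcam by passing a PIL frame and the bundled sphere mask, e.g.
#
#   frame = Image.new(mode="RGB", size=imageSize)
#   sphereMask = Image.open("assets/masks/sphere.png")
#   preview = diffuse(True, frame, sphereMask, False,
#                     "A person in a room", "", 0.75, 25, 4096)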

# Gradio UI components (3.x-style Image args: source/shape/streaming)
defaultMask = Image.open("assets/masks/sphere.png")
prompt = gradio.Textbox(label="Prompt", placeholder="A person in a room", lines=3)
negativePrompt = gradio.Textbox(label="Negative Prompt", placeholder="Text", lines=3)
inputImage = gradio.Image(label="Input Feed", source="webcam", shape=[512, 512], streaming=True)
mask = gradio.Image(label="Mask", type="pil", value=defaultMask)
outputImage = gradio.Image(label="Extrapolated Field of View")
guidanceScale = gradio.Slider(label="Guidance Scale", maximum=1, value=0.75)
numInferenceSteps = gradio.Slider(label="Number of Inference Steps", maximum=100, value=25)
seed = gradio.Slider(label="Generator Seed", maximum=10000, value=4096)
staticLatents = gradio.Checkbox(label="Static Latents", value=True)
pauseInference = gradio.Checkbox(label="Pause Inference", value=False)
#generateNewLatents = gradio.Button(label="Generate New Latents")
#generateNewLatents.click(GenerateNewLatentsForInference)

inputs = [staticLatents, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps, seed]

# live=True re-runs diffuse whenever an input changes, so the webcam stream drives continuous inpainting
ux = gradio.Interface(fn=diffuse, title="View Diffusion", inputs=inputs, outputs=outputImage, live=True)
ux.launch()