# ViewDiffusion / inference.py
# BertChristiaens - "Update inference.py" (44e42ee), 1.2 kB
import streamlit as st
import torch
import numpy
from PIL import Image
from torchvision import transforms
from diffusers import StableDiffusionInpaintPipeline
from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
@torch.inference_mode()
@st.cache_resource
def get_pipeline():
    """Load and configure the Stable Diffusion 2 inpainting pipeline.

    The function is wrapped in ``st.cache_resource`` so that the configured
    pipeline object can be reused by later calls instead of being rebuilt.

    Returns:
        A ``StableDiffusionInpaintPipeline`` loaded from
        ``stabilityai/stable-diffusion-2-inpainting``, moved to the selected
        device, with its progress bar disabled and its scheduler replaced by
        a ``UniPCMultistepScheduler`` built from the original scheduler config.
    """
    # No module-level ``device`` is defined in this file, so the target device
    # is chosen here; referencing an undefined global would raise NameError.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        # Half precision on GPU (as before); float32 on the CPU fallback,
        # where half-precision kernels are presumably not fully supported.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    pipe.to(device)

    if use_cuda:
        # xformers memory-efficient attention requires CUDA, so it is only
        # enabled when a GPU is actually in use.
        pipe.enable_xformers_memory_efficient_attention()

    pipe.set_progress_bar_config(disable=True)
    # Swap the default scheduler for UniPC, reusing the existing config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe
def inpainting(image,
               mask_image,
               prompt,
               negative_prompt,
               num_inference_steps=20,
               guidance_scale=7.5,
               ):
    """Run the cached inpainting pipeline and return its first image.

    Args:
        image: Source image forwarded to the pipeline as ``image``
            (presumably a PIL image; confirm against the caller).
        mask_image: Mask forwarded to the pipeline as ``mask_image``.
        prompt: Text prompt forwarded as ``prompt``.
        negative_prompt: Text forwarded as ``negative_prompt``.
        num_inference_steps: Forwarded as ``num_inference_steps``
            (default 20).
        guidance_scale: Forwarded as ``guidance_scale`` (default 7.5).

    Returns:
        The first element of the ``images`` attribute of the pipeline output.
    """
    # Collect the call arguments first so the pipeline call stays on one line.
    call_kwargs = {
        "image": image,
        "mask_image": mask_image,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    pipeline = get_pipeline()
    output = pipeline(**call_kwargs)
    return output.images[0]