# ViewDiffusion / inference.py
# Last commit: "Update inference.py" (6db47b5), BertChristiaens
import streamlit as st
import torch
import numpy
from PIL import Image
from torchvision import transforms
from diffusers import StableDiffusionInpaintPipeline
from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
@torch.inference_mode()
@st.cache_resource
def get_pipeline():
    """Load and configure the Stable Diffusion 2 inpainting pipeline.

    The function is wrapped in ``st.cache_resource``, so the pipeline is
    built once and the same object is handed back on later calls.

    The target device is chosen here: CUDA when ``torch.cuda.is_available()``
    reports a GPU, otherwise CPU. Half precision and xformers attention are
    only used on the CUDA path.

    Returns:
        The ``StableDiffusionInpaintPipeline`` instance, moved to the selected
        device, with its progress bar disabled and its scheduler replaced by
        a ``UniPCMultistepScheduler`` built from the original scheduler config.
    """
    # `device` used to be read here without ever being defined in this
    # module, which raised NameError on the first call. Resolve it locally.
    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        # float16 is kept for the GPU path (the original setting); the CPU
        # fallback uses float32 -- NOTE(review): half precision on CPU is
        # assumed to be unsupported/slow for this model, confirm if needed.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
    )
    pipe.to(device)

    if on_gpu:
        # NOTE(review): xformers memory-efficient attention is assumed to be
        # CUDA-only, so it is skipped on the CPU fallback.
        pipe.enable_xformers_memory_efficient_attention()

    pipe.set_progress_bar_config(disable=True)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe
def inpainting(image,
               mask_image,
               prompt,
               negative_prompt,
               num_inference_steps=20,
               guidance_scale=7.5,
               ):
    """Run one inpainting pass and return the first generated image.

    Every argument is forwarded unchanged, by keyword, to the pipeline
    returned by ``get_pipeline()``.

    Args:
        image: Source image handed to the pipeline (presumably a PIL image;
            verify against the caller).
        mask_image: Mask handed to the pipeline alongside ``image``.
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative text prompt forwarded to the pipeline.
        num_inference_steps: Step count forwarded to the pipeline (default 20).
        guidance_scale: Guidance scale forwarded to the pipeline (default 7.5).

    Returns:
        The first entry of the ``images`` attribute of the pipeline output.
    """
    pipeline = get_pipeline()

    call_kwargs = {
        "image": image,
        "mask_image": mask_image,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    output = pipeline(**call_kwargs)

    # Only the first generated image is used.
    first_image = output.images[0]
    print("Generated image")
    return first_image