# app.py — commit 9be40eb "Update app.py" by pr4nav101 (verified)
from diffusers import StableDiffusionInpaintPipeline
import torch
# Stable Diffusion 2 inpainting pipeline, loaded with float16 weights and
# moved to the CUDA device; the `inpaint` handler calls it to fill the mask.
model_id = 'stabilityai/stable-diffusion-2-inpainting'
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to("cuda")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# Segment Anything (SAM) predictor used to turn user clicks into a mask.
# The model type and checkpoint path default to the ViT-H checkpoint published
# by the segment-anything project and can be overridden via environment
# variables, so a different checkpoint does not require a code change.
# NOTE(review): the checkpoint file must exist at CHECKPOINT_PATH on disk —
# confirm it is downloaded alongside this app.
import os

MODEL_TYPE = os.environ.get("SAM_MODEL_TYPE", "vit_h")
CHECKPOINT_PATH = os.environ.get("SAM_CHECKPOINT_PATH", "sam_vit_h_4b8939.pth")
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to("cuda")
predictor = SamPredictor(sam)
import gradio as gr
import numpy as np
from PIL import Image
# Module-level click state: the list of clicked positions (evt.index values)
# and a flag telling the mask step whether to invert the SAM mask.
selected_pixels, isInvert = [], 0
# Text prepended to every user prompt sent to the inpainting pipeline.
_BASE_PROMPT = "Realistic professional Headshot of a man for a profile pic"
# Things the inpainting pipeline is asked to avoid.
_NEGATIVE_PROMPT = """
duplicate,low quality, lowest quality, bad shape,bad anatomy,
bad proportions, lowres,error,watermark,username,artistname,
signature,text,jpeg artifacts,blurry,more than one person,simple background
"""
# Both the image and the mask are resized to this size before inpainting.
_SD_SIZE = (512, 512)


def _predict_mask(image, points, invert):
    """Return a SAM mask (PIL image) for *image* prompted by *points*.

    Args:
        image: the input image array delivered by the "Input" component.
        points: clicked positions (``gr.SelectData.index`` values).
        invert: when truthy, the boolean mask is logically inverted.

    Returns:
        A PIL image built from the single predicted boolean mask.
    """
    predictor.set_image(image)
    coords = np.array(points)
    # Every click gets label 1 (SAM's foreground-point label).
    labels = np.ones(coords.shape[0])
    masks, _, _ = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=False,
    )
    mask = masks[0]
    if invert:
        mask = np.logical_not(mask)
    return Image.fromarray(mask)


def reset_points():
    """Forget earlier clicks and the stale mask when the input image changes.

    Returns:
        ``([], None)`` — an empty point list and a cleared mask.
    """
    return [], None


def generate_mask(image, points, invert, evt: gr.SelectData):
    """Add the clicked point to this session's points and recompute the mask.

    Args:
        image: the input image array.
        points: this session's previously clicked points (from ``gr.State``).
        invert: value of the "Invert Mask" checkbox.
        evt: Gradio select event; ``evt.index`` is the clicked position.

    Returns:
        ``(points, mask)`` — the updated point list, stored back into the
        session state, and the new mask image.
    """
    if image is None:
        return points, None
    # Build a new list instead of mutating the state value in place.
    points = [*points, evt.index]
    return points, _predict_mask(image, points, invert)


def invertmask(image, points, invert):
    """Recompute the mask for the current points when "Invert Mask" toggles.

    Returns:
        The recomputed mask, or ``None`` when there is no image or no click yet.
    """
    if image is None or not points:
        return None
    return _predict_mask(image, points, invert)


def inpaint(img, mask, prompt):
    """Inpaint the masked region of *img*, guided by *prompt*.

    Args:
        img: input image array from the "Input" component.
        mask: mask image array from the "Mask" component.
        prompt: free-text user prompt appended to ``_BASE_PROMPT``.

    Returns:
        The first image produced by the inpainting pipeline.

    Raises:
        gr.Error: if the input image or the mask is missing, so the user sees
            a message instead of a server-side traceback.
    """
    if img is None:
        raise gr.Error("Please upload an input image first.")
    if mask is None:
        raise gr.Error("Please click on the input image to create a mask first.")
    img = Image.fromarray(img).resize(_SD_SIZE)
    # NEAREST keeps the mask hard-edged instead of smoothing its border.
    mask = Image.fromarray(mask).resize(_SD_SIZE, Image.NEAREST)
    prompt = (prompt or "").strip()
    # Separate the base prompt from the user's text so words don't fuse.
    full_prompt = f"{_BASE_PROMPT}, {prompt}" if prompt else _BASE_PROMPT
    return sd_pipeline(
        prompt=full_prompt,
        image=img,
        negative_prompt=_NEGATIVE_PROMPT,
        mask_image=mask,
    ).images[0]


with gr.Blocks() as genaieg:
    # Clicked points per browser session: gr.State gives each session its
    # own copy instead of one list shared by every user.
    points_state = gr.State([])
    with gr.Row():
        input_img = gr.Image(label='Input')
        mask_img = gr.Image(label="Mask")
    with gr.Row():
        output_img = gr.Image(label="Output")
    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label='Prompt')
        submit = gr.Button('Submit')
        # A checkbox can be switched on and off; a one-option Radio cannot be
        # deselected, so its visible state would drift from the invert flag.
        invert_mask = gr.Checkbox(label='Invert Mask')

    # A new (or cleared) image invalidates previous clicks and the mask.
    input_img.change(reset_points, None, [points_state, mask_img])
    input_img.select(
        generate_mask,
        [input_img, points_state, invert_mask],
        [points_state, mask_img],
    )
    invert_mask.change(
        invertmask,
        [input_img, points_state, invert_mask],
        [mask_img],
    )
    submit.click(
        inpaint,
        inputs=[input_img, mask_img, prompt_text],
        outputs=[output_img],
    )
def main():
    """Launch the `genaieg` Gradio app with debug=True."""
    genaieg.launch(debug=True)


if __name__ == '__main__':
    main()