"""Gradio demo: select a region with Segment Anything (SAM), then inpaint it with Stable Diffusion 2.

Workflow: upload an input image, click one or more points on it (every click is
passed to SAM as a foreground point), optionally tick "Invert Mask", type a
prompt, and press Submit to inpaint the masked region.
"""
import os

from diffusers import StableDiffusionInpaintPipeline
import torch

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

import gradio as gr
import numpy as np
from PIL import Image

# Stable Diffusion 2 inpainting pipeline, loaded in half precision on the GPU.
model_id = 'stabilityai/stable-diffusion-2-inpainting'
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
sd_pipeline = sd_pipeline.to("cuda")

# SAM backbone type and checkpoint path. The original script used these names
# without defining them (NameError at startup). Defaults correspond to the
# official ViT-H release; override them with environment variables.
MODEL_TYPE = os.environ.get("SAM_MODEL_TYPE", "vit_h")
CHECKPOINT_PATH = os.environ.get("SAM_CHECKPOINT_PATH", "sam_vit_h_4b8939.pth")
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to("cuda")
predictor = SamPredictor(sam)

# Legacy module-level state, used only by the two-argument generate_mask() and
# by invertmask(). The UI below keeps points and the invert flag per session.
selected_pixels = []
isInvert = False

# Prompt text sent to the inpainting pipeline.
BASE_PROMPT = "Realistic professional Headshot of a man for a profile pic"
NEGATIVE_PROMPTS = """ duplicate,low quality, lowest quality, bad shape,bad anatomy, bad proportions, lowres,error,watermark,username,artistname, signature,text,jpeg artifacts,blurry,more than one person,simple background """


def _predict_mask(image, points, invert):
    """Return a SAM mask for ``image`` as a PIL "L" image with values 0 or 255.

    Args:
        image: The image passed straight to ``SamPredictor.set_image``.
        points: Sequence of (x, y) click positions; all are used as foreground points.
        invert: When truthy, the predicted mask is inverted.
    """
    predictor.set_image(image)
    coords = np.array(points)
    labels = np.ones(coords.shape[0])
    masks, _, _ = predictor.predict(
        point_coords=coords,
        point_labels=labels,
        multimask_output=False,
    )
    # With multimask_output=False only the first mask of the returned stack is used.
    mask = masks[0]
    if invert:
        mask = np.logical_not(mask)
    # Emit 0/255 uint8 instead of a 1-bit image so later RGB conversion gives a
    # clean black/white mask.
    return Image.fromarray(mask.astype(np.uint8) * 255)


def invertmask():
    """Toggle the legacy module-level invert flag read by ``generate_mask(image, evt)``."""
    global isInvert
    isInvert = not isInvert


def generate_mask(image, evt: gr.SelectData):
    """Append the clicked point to the module-level list and return the SAM mask.

    Kept for backward compatibility. The UI uses the per-session
    ``_on_image_click`` so that clicks are not shared between users or images.
    """
    selected_pixels.append(evt.index)
    return _predict_mask(image, selected_pixels, isInvert)


def _on_image_click(image, points, invert, evt: gr.SelectData):
    """Add the clicked (x, y) to this session's points; return (mask, updated points)."""
    points = [*points, evt.index]
    return _predict_mask(image, points, invert), points


def _refresh_mask(image, points, invert):
    """Recompute the mask when the invert checkbox changes; None if nothing is selected."""
    if image is None or not points:
        return None
    return _predict_mask(image, points, invert)


def _reset_points():
    """Clear the session's points and the mask when the input image changes."""
    return [], None


def inpaint(img, mask, prompt):
    """Inpaint the masked region of ``img`` with Stable Diffusion 2.

    Args:
        img: Input image array, as delivered by the "Input" ``gr.Image``.
        mask: Mask image array, as delivered by the "Mask" ``gr.Image``.
        prompt: Extra user text appended to the fixed base prompt (may be empty).

    Returns:
        The inpainted PIL image, resized back to the input image's size.

    Raises:
        gr.Error: If no input image or no mask is available yet.
    """
    if img is None:
        raise gr.Error("Please upload an input image.")
    if mask is None:
        raise gr.Error("Click on the input image to select a region to inpaint first.")

    image = Image.fromarray(img)
    original_size = image.size
    image = image.resize((512, 512))
    # Nearest-neighbour keeps the mask strictly black/white after resizing.
    mask_image = Image.fromarray(mask).resize((512, 512), Image.NEAREST)

    # Join with a separator; the original glued the user text directly onto "pic".
    prompt = (prompt or "").strip()
    full_prompt = f"{BASE_PROMPT}, {prompt}" if prompt else BASE_PROMPT

    output = sd_pipeline(
        prompt=full_prompt,
        image=image,
        negative_prompt=NEGATIVE_PROMPTS,
        mask_image=mask_image,
    ).images[0]
    # Inputs were squashed to 512x512; scale the result back so it lines up with the input.
    return output.resize(original_size)


with gr.Blocks() as genaieg:
    # Per-session list of clicked (x, y) points.
    points_state = gr.State([])
    with gr.Row():
        input_img = gr.Image(label='Input')
        mask_img = gr.Image(label="Mask")
    with gr.Row():
        output_img = gr.Image(label="Output")
    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label='Prompt')
        submit = gr.Button('Submit')
        # A checkbox can be toggled both ways, unlike a single-option radio.
        invert_box = gr.Checkbox(label='Invert Mask')

    input_img.select(
        _on_image_click,
        inputs=[input_img, points_state, invert_box],
        outputs=[mask_img, points_state],
    )
    # A new (or cleared) image invalidates previously clicked points.
    input_img.change(_reset_points, inputs=None, outputs=[points_state, mask_img])
    invert_box.change(
        _refresh_mask,
        inputs=[input_img, points_state, invert_box],
        outputs=[mask_img],
    )
    submit.click(inpaint, inputs=[input_img, mask_img, prompt_text], outputs=[output_img])

if __name__ == '__main__':
    genaieg.launch(debug=True)