pr4nav101 commited on
Commit
d6e5ca6
·
verified ·
1 Parent(s): 5e31674

Upload sdxl_sam.py

Browse files
Files changed (1) hide show
  1. sdxl_sam.py +107 -0
sdxl_sam.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """SDXL-SAM.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/github/pranavsrinivasa/SDXL-SAM-HEADSHOT-CREATOR/blob/main/SDXL-SAM.ipynb
8
+ """
9
+
10
+ !pip install diffusers
11
+ !pip install transformers
12
+
13
+ from diffusers import StableDiffusionInpaintPipeline
14
+ import torch
15
+
16
+ model_id = 'stabilityai/stable-diffusion-2-inpainting'
17
+ sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(model_id,torch_dtype = torch.float16)
18
+ sd_pipeline = sd_pipeline.to("cuda")
19
+
20
+ !pip install gradio==3.48.0
21
+
22
+ !pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'
23
+ !pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision
24
+
25
+ !mkdir -p {HOME}/weights
26
+ !wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights
27
+
28
+ import os
29
+
30
+ CHECKPOINT_PATH = os.path.join("{HOME}", "weights", "sam_vit_h_4b8939.pth")
31
+ print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))
32
+ MODEL_TYPE = "vit_h"
33
+
34
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
35
+
36
+ sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to("cuda")
37
+
38
+ predictor = SamPredictor(sam)
39
+
40
+ import gradio as gr
41
+ import numpy as np
42
+ from PIL import Image
43
+
44
+ selected_pixels = []
45
+ isInvert = 0
46
+
47
+ with gr.Blocks() as genaieg:
48
+ selected_pixels = []
49
+ isInvert = 0
50
+ with gr.Row():
51
+ input_img = gr.Image(label = 'Input')
52
+ mask_img = gr.Image(label = "Mask")
53
+ with gr.Row():
54
+ output_img = gr.Image(label = "Ouput")
55
+
56
+ def invertmask():
57
+ global isInvert
58
+ isInvert = not(isInvert)
59
+
60
+ with gr.Row():
61
+ prompt_text = gr.Textbox(line = 1,label = 'Prompt')
62
+ submit = gr.Button('Submit')
63
+ radio = gr.Radio(['Invert Mask'])
64
+ radio.select(fn = invertmask)
65
+
66
+ def generate_mask(image, evt: gr.SelectData):
67
+ selected_pixels.append(evt.index)
68
+ predictor.set_image(image)
69
+ input_points = np.array(selected_pixels)
70
+ input_label = np.ones(input_points.shape[0])
71
+ mask, _, _ = predictor.predict(
72
+ point_coords = input_points,
73
+ point_labels = input_label,
74
+ multimask_output = False
75
+ )
76
+ if isInvert:
77
+ mask = np.logical_not(mask)
78
+ mask = Image.fromarray(mask[0,:,:])
79
+ return mask
80
+
81
+
82
+ def inpaint(img, mask, prompt):
83
+ img = Image.fromarray(img)
84
+ mask = Image.fromarray(mask)
85
+ img = img.resize((512,512))
86
+ mask = mask.resize((512,512))
87
+ negative_prompts = """
88
+ duplicate,low quality, lowest quality, bad shape,bad anatomy,
89
+ bad proportions, lowres,error,watermark,username,artistname,
90
+ signature,text,jpeg artifacts,blurry,more than one person,simple background
91
+ """
92
+ prompt_text = "Realistic professinal Headshot of a man for a profile pic" + prompt
93
+ output = sd_pipeline(prompt = prompt_text,
94
+ image = img,
95
+ negative_prompt = negative_prompts,
96
+ mask_image = mask).images[0]
97
+
98
+ return output
99
+
100
+
101
+ input_img.select(generate_mask, [input_img],[mask_img])
102
+ submit.click(inpaint,
103
+ inputs=[input_img,mask_img,prompt_text],
104
+ outputs = [output_img])
105
+
106
+ if __name__ == '__main__':
107
+ genaieg.launch(debug = True)