itberrios committed
Commit 395679e · 1 Parent(s): 53e96bd

updated app

Files changed (2):
  1. app.py +130 -19
  2. model.py +72 -0
app.py CHANGED
@@ -1,34 +1,145 @@
- import streamlit as st
  from PIL import Image

  import torch
  from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
  from diffusers import StableDiffusionInpaintPipeline

- device = 'cuda' if torch.cuda.is_available() else 'cpu'

- seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
- seg_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")

- # get Stable Diffusion model for image inpainting
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
-     "runwayml/stable-diffusion-inpainting",
-     torch_dtype=torch.float16,
- ).to(device)

- st.title("Stable Edit - Edit your photos with Stable Diffusion!")

- # upload image
- filename = st.file_uploader("upload an image")
- image = Image.open(filename)
- st.image(image)

- # Select Area to edit
- selection = st.selectbox("Select Area(s) to edit", ("AutoSegment Area", "Draw Custom Area", "Suprise Me"))

- # TEMP - DEMO stuff
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ import numpy as np
+ import pandas as pd
  from PIL import Image
+ from collections import defaultdict
+
+ import streamlit as st
+ from streamlit_drawable_canvas import st_canvas

  import torch
  from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
  from diffusers import StableDiffusionInpaintPipeline

+ import matplotlib as mpl
+
+ from model import segment_image, inpaint
+
+
+ # define utils and helpers
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ def closest_number(n, m=8):
+     """ Obtains the closest number to n that is divisible by m """
+     return int(n/m) * m
+
+
+ def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
+     # Create a canvas component
+     canvas_result = st_canvas(
+         fill_color="rgba(255, 165, 0, 0.3)",
+         stroke_width=2,
+         stroke_color="#000000",
+         background_image=image,
+         update_streamlit=True,
+         height=height,
+         width=width,
+         drawing_mode=drawing_mode,
+         point_display_radius=5,
+         key="canvas",
+     )
+
+     # get selections from canvas
+     if canvas_result.json_data is not None:
+         objects = pd.json_normalize(canvas_result.json_data["objects"])
+         for col in objects.select_dtypes(include=["object"]).columns:
+             objects[col] = objects[col].astype("str")
+
+         if len(objects) > 0:
+             left_coords = objects.left.to_numpy()
+             top_coords = objects.top.to_numpy()
+             right_coords = left_coords + objects.width.to_numpy()
+             bottom_coords = top_coords + objects.height.to_numpy()
+
+             # add selections to mask
+             for (left, top, right, bottom) in zip(left_coords, top_coords, right_coords, bottom_coords):
+                 cropped = image.crop((left, top, right, bottom))
+                 st.image(cropped)
+                 mask[top:bottom, left:right] = 255
+
+             st.header("Mask Created!")
+             st.image(mask)
+
+     return mask
+
+
+ def get_mask(image, edit_method, height, width):
+     mask = np.zeros((height, width), dtype=np.uint8)
+
+     if edit_method == "AutoSegment Area":
+         # get displayable segmented image
+         seg_prediction, segment_labels = segment_image(image)
+         seg = seg_prediction['segmentation'].cpu().numpy()
+         viridis = mpl.colormaps.get_cmap('viridis').resampled(np.max(seg))
+         seg_image = Image.fromarray(np.uint8(viridis(seg)*255))
+
+         # display image
+         st.image(seg_image)
+
+         # prompt user to select valid labels to edit
+         seg_selections = st.multiselect("Choose segments", zip(segment_labels.keys(), segment_labels.values()))
+         if seg_selections:
+             tgts = []
+             for s in seg_selections:
+                 tgts.append(s[0])
+
+             mask = Image.fromarray(np.array([(seg == t) for t in tgts]).sum(axis=0).astype(np.uint8)*255)
+             st.header("Mask Created!")
+             st.image(mask)
+
+     elif edit_method == "Draw Custom Area":
+         mask = get_mask_from_rectangles(image, mask, height, width)
+
+     return mask
+
+
+ if __name__ == '__main__':
+
+     st.title("Stable Edit - Edit your photos with Stable Diffusion!")
+
+     # upload image
+     filename = st.file_uploader("upload an image")
+
+     sf = st.text_input("Please enter resizing scale factor to downsize image (default = 2)", value="2")
+     try:
+         sf = int(sf)
+     except ValueError:
+         st.write("Error with input scale factor, setting to default value of 2, please re-enter above to change it")
+         sf = 2
+
+     if filename:
+         image = Image.open(filename)

+         width, height = image.size
+         width, height = closest_number(width/sf), closest_number(height/sf)
+         image = image.resize((width, height))

+         st.image(image)

+         # Select an editing method
+         edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))

+         if edit_method:
+             mask = get_mask(image, edit_method, height, width)

+             # get inpainted images
+             prompt = st.text_input("Please enter prompt for image inpainting", value="")

+             st.write("Inpainting Images, patience is a virtue :)")

+             images = inpaint(image, mask, width, height, prompt=prompt, seed=0, guidance_scale=17.5, num_samples=3)

+             # display all images
+             st.write("Original Image")
+             st.image(image)
+             for i, img in enumerate(images, 1):
+                 st.write(f"result: {i}")
+                 st.image(img)
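Note on the resize step above: the app snaps both dimensions down to multiples of 8 before inpainting, since Stable Diffusion's VAE downsamples by a factor of 8 and the pipeline expects divisible sizes. A minimal sketch of that preprocessing in isolation (the 1023x771 input is just an illustrative stand-in for an uploaded photo):

    from PIL import Image

    def closest_number(n, m=8):
        """ Round n down to the nearest multiple of m """
        return int(n/m) * m

    image = Image.new("RGB", (1023, 771))  # stand-in for an uploaded photo
    sf = 2  # downsizing scale factor, the app's default
    width, height = closest_number(image.width / sf), closest_number(image.height / sf)
    image = image.resize((width, height))
    print(image.size)  # (504, 384): both divisible by 8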
model.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
+ from diffusers import StableDiffusionInpaintPipeline
+
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+ # Image segmentation
+ seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+ seg_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+
+ def segment_image(image):
+     inputs = seg_processor(image, return_tensors="pt")
+
+     with torch.no_grad():
+         seg_outputs = seg_model(**inputs)
+
+     # get prediction dict
+     seg_prediction = seg_processor.post_process_panoptic_segmentation(seg_outputs, target_sizes=[image.size[::-1]])[0]
+
+     # get segment labels dict
+     segment_labels = {}
+     for segment in seg_prediction['segments_info']:
+         segment_id = segment['id']
+         segment_label_id = segment['label_id']
+         segment_label = seg_model.config.id2label[segment_label_id]
+
+         segment_labels.update({segment_id : segment_label})
+
+     return seg_prediction, segment_labels
+
+
+ # Image inpainting
+ # get Stable Diffusion model for image inpainting
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
+     "runwayml/stable-diffusion-inpainting",
+     torch_dtype=torch.float16,
+ ).to(device)
+
+
+ def inpaint(image, mask, W, H, prompt="", seed=0, guidance_scale=17.5, num_samples=3):
+     """ Uses Stable Diffusion model to inpaint image
+         Inputs:
+             image - input image (PIL or torch tensor)
+             mask - mask for inpainting, same size as image (PIL or torch tensor)
+             W - width of the image
+             H - height of the image
+             prompt - prompt for inpainting
+             seed - random seed
+             guidance_scale - classifier-free guidance weight
+             num_samples - number of images to generate
+         Outputs:
+             images - output images
+     """
+     generator = torch.Generator(device=device).manual_seed(seed)
+     images = pipe(
+         prompt=prompt,
+         image=image,
+         mask_image=mask,  # ensure mask is same type as image
+         height=H,
+         width=W,
+         guidance_scale=guidance_scale,
+         generator=generator,
+         num_images_per_prompt=num_samples,
+     ).images
+
+     return images
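For orientation, model.py exposes the two entry points that app.py imports: segment_image for Mask2Former panoptic segmentation and inpaint for Stable Diffusion inpainting. A minimal sketch of driving them outside Streamlit, assuming the models load successfully (the file paths and segment id 1 are placeholders):

    import numpy as np
    from PIL import Image
    from model import segment_image, inpaint

    # load and resize so both dimensions are divisible by 8
    image = Image.open("photo.jpg").resize((512, 384))  # placeholder path

    # map each panoptic segment id to a human-readable label
    seg_prediction, segment_labels = segment_image(image)
    print(segment_labels)  # e.g. {1: 'sky', 2: 'person', ...}

    # build a binary mask for one chosen segment and inpaint it
    seg = seg_prediction['segmentation'].cpu().numpy()
    mask = Image.fromarray((seg == 1).astype(np.uint8) * 255)

    images = inpaint(image, mask, 512, 384, prompt="a clear blue sky", num_samples=1)
    images[0].save("edited.jpg")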