michaelapplydesign commited on
Commit
8f75876
·
1 Parent(s): 09f6c11
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: RemoveFurnitureV1
3
  emoji: 👀
4
  colorFrom: red
5
  colorTo: indigo
 
1
  ---
2
+ title: V1
3
  emoji: 👀
4
  colorFrom: red
5
  colorTo: indigo
app.py CHANGED
@@ -45,7 +45,8 @@ def upscale(image):
45
  with gr.Blocks() as app:
46
  with gr.Row():
47
 
48
- gr.Button("FurnituRemove").click(removeFurniture,
 
49
  inputs=[gr.Image(label="img", type="pil"),
50
  gr.Image(label="mask", type="pil"),
51
  gr.Textbox(label="positive_prompt",value="empty room"),
@@ -64,9 +65,9 @@ with gr.Blocks() as app:
64
  gr.Image(),
65
  gr.Image(),
66
  gr.Image()])
 
 
 
 
67
 
68
- gr.Button("Segmentation").click(segmentation, inputs=gr.Image(type="pil"), outputs=gr.Image())
69
-
70
- gr.Button("Upscale").click(upscale, inputs=gr.Image(type="pil"), outputs=gr.Image())
71
-
72
- app.launch(debug=True)
 
45
  with gr.Blocks() as app:
46
  with gr.Row():
47
 
48
+ with gr.Column():
49
+ gr.Button("FurnituRemove").click(removeFurniture,
50
  inputs=[gr.Image(label="img", type="pil"),
51
  gr.Image(label="mask", type="pil"),
52
  gr.Textbox(label="positive_prompt",value="empty room"),
 
65
  gr.Image(),
66
  gr.Image(),
67
  gr.Image()])
68
+ with gr.Column():
69
+ gr.Button("Segmentation").click(segmentation, inputs=gr.Image(type="pil"), outputs=gr.Image())
70
+ with gr.Column():
71
+ gr.Button("Upscale").click(upscale, inputs=gr.Image(type="pil"), outputs=gr.Image())
72
 
73
+ app.launch(debug=True,share=True)
 
 
 
 
colors.py DELETED
@@ -1,343 +0,0 @@
1
- """Color mappings"""
2
- from typing import List, Dict
3
-
4
- TRIVIA = {
5
- "#B47878": "building;edifice",
6
- "#06E6E6": "sky",
7
- "#04C803": "tree",
8
- "#8C8C8C": "road;route",
9
- "#04FA07": "grass",
10
- "#96053D": "person;individual;someone;somebody;mortal;soul",
11
- "#CCFF04": "plant;flora;plant;life",
12
- "#787846": "earth;ground",
13
- "#FF09E0": "house",
14
- "#0066C8": "car;auto;automobile;machine;motorcar",
15
- "#3DE6FA": "water",
16
- "#FF3D06": "railing;rail",
17
- "#FF5C00": "arcade;machine",
18
- "#FFE000": "stairs;steps",
19
- "#00F5FF": "fan",
20
- "#FF008F": "step;stair",
21
- "#1F00FF": "stairway;staircase",
22
- "#FFD600": "radiator",
23
- }
24
-
25
- OBJECTS = {
26
- "#CC05FF": "bed",
27
- "#FF0633": "painting;picture",
28
- "#DCDCDC": "mirror",
29
- "#00FF14": "box",
30
- "#FF0000": "flower",
31
- "#FFA300": "book",
32
- "#00FFC2": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
33
- "#F500FF": "pot;flowerpot",
34
- "#00FFCC": "vase",
35
- "#29FF00": "tray",
36
- "#8FFF00": "poster;posting;placard;notice;bill;card",
37
- "#5CFF00": "basket;handbasket",
38
- "#00ADFF": "screen;door;screen",
39
- }
40
-
41
-
42
- SITTING = {
43
- "#0B66FF": "sofa;couch;lounge",
44
- "#CC4603": "chair",
45
- "#07FFE0": "seat",
46
- "#08FFD6": "armchair",
47
- "#FFC207": "cushion",
48
- "#00EBFF": "pillow",
49
- "#00D6FF": "stool",
50
- "#1400FF": "blanket;cover",
51
- "#0A00FF": "swivel;chair",
52
- "#FF9900": "ottoman;pouf;pouffe;puff;hassock",
53
- }
54
-
55
- LIGHTING = {
56
- "#E0FF08": "lamp",
57
- "#FFAD00": "light;light;source",
58
- "#001FFF": "chandelier;pendant;pendent",
59
- }
60
-
61
- TABLES = {
62
- "#FF0652": "table",
63
- "#0AFF47": "desk",
64
- }
65
-
66
- CLOSETS = {
67
- "#E005FF": "cabinet",
68
- "#FF0747": "shelf",
69
- "#07FFFF": "wardrobe;closet;press",
70
- "#0633FF": "chest;of;drawers;chest;bureau;dresser",
71
- "#0000FF": "case;display;case;showcase;vitrine",
72
- }
73
-
74
-
75
- BATHROOM = {
76
- "#6608FF": "bathtub;bathing;tub;bath;tub",
77
- "#00FF85": "toilet;can;commode;crapper;pot;potty;stool;throne",
78
- "#0085FF": "shower",
79
- "#FF0066": "towel",
80
- }
81
-
82
- WINDOWS = {
83
- "#FF3307": "curtain;drape;drapery;mantle;pall",
84
- "#E6E6E6": "windowpane;window",
85
- "#00FF3D": "awning;sunshade;sunblind",
86
- "#003DFF": "blind;screen",
87
- }
88
-
89
- FLOOR = {
90
- "#FF095C": "rug;carpet;carpeting",
91
- "#503232": "floor;flooring",
92
- }
93
-
94
- INTERIOR = {
95
- "#787878": "wall",
96
- "#787850": "ceiling",
97
- "#08FF33": "door;double;door",
98
- }
99
-
100
- KITCHEN = {
101
- "#00FF29": "kitchen;island",
102
- "#14FF00": "refrigerator;icebox",
103
- "#00A3FF": "sink",
104
- "#EB0CFF": "counter",
105
- "#D6FF00": "dishwasher;dish;washer;dishwashing;machine",
106
- "#FF00EB": "microwave;microwave;oven",
107
- "#47FF00": "oven",
108
- "#66FF00": "clock",
109
- "#00FFB8": "plate",
110
- "#19C2C2": "glass;drinking;glass",
111
- "#00FF99": "bar",
112
- "#00FF0A": "bottle",
113
- "#FF7000": "buffet;counter;sideboard",
114
- "#B800FF": "washer;automatic;washer;washing;machine",
115
- "#00FF70": "coffee;table;cocktail;table",
116
- "#008FFF": "countertop",
117
- "#33FF00": "stove;kitchen;stove;range;kitchen;range;cooking;stove",
118
- }
119
-
120
- LIVINGROOM = {
121
- "#FA0A0F": "fireplace;hearth;open;fireplace",
122
- "#FF4700": "pool;table;billiard;table;snooker;table",
123
- }
124
-
125
- OFFICE = {
126
- "#00FFAD": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
127
- "#00FFF5": "bookcase",
128
- "#0633FF": "chest;of;drawers;chest;bureau;dresser",
129
- "#005CFF": "monitor;monitoring;device",
130
- }
131
-
132
-
133
- COLOR_MAPPING_CATEGORY_ = {
134
- 'keep background': {'#FFFFFF': 'background'},
135
- 'trivia': TRIVIA,
136
- 'objects': OBJECTS,
137
- 'sitting': SITTING,
138
- 'lighting': LIGHTING,
139
- 'tables': TABLES,
140
- 'closets': CLOSETS,
141
- 'bathroom': BATHROOM,
142
- 'windows': WINDOWS,
143
- 'floor': FLOOR,
144
- 'interior': INTERIOR,
145
- 'kitchen': KITCHEN,
146
- 'livingroom': LIVINGROOM,
147
- 'office': OFFICE}
148
-
149
-
150
- COLOR_MAPPING_ = {
151
- '#FFFFFF': 'background',
152
- "#787878": "wall",
153
- "#B47878": "building;edifice",
154
- "#06E6E6": "sky",
155
- "#503232": "floor;flooring",
156
- "#04C803": "tree",
157
- "#787850": "ceiling",
158
- "#8C8C8C": "road;route",
159
- "#CC05FF": "bed",
160
- "#E6E6E6": "windowpane;window",
161
- "#04FA07": "grass",
162
- "#E005FF": "cabinet",
163
- "#EBFF07": "sidewalk;pavement",
164
- "#96053D": "person;individual;someone;somebody;mortal;soul",
165
- "#787846": "earth;ground",
166
- "#08FF33": "door;double;door",
167
- "#FF0652": "table",
168
- "#8FFF8C": "mountain;mount",
169
- "#CCFF04": "plant;flora;plant;life",
170
- "#FF3307": "curtain;drape;drapery;mantle;pall",
171
- "#CC4603": "chair",
172
- "#0066C8": "car;auto;automobile;machine;motorcar",
173
- "#3DE6FA": "water",
174
- "#FF0633": "painting;picture",
175
- "#0B66FF": "sofa;couch;lounge",
176
- "#FF0747": "shelf",
177
- "#FF09E0": "house",
178
- "#0907E6": "sea",
179
- "#DCDCDC": "mirror",
180
- "#FF095C": "rug;carpet;carpeting",
181
- "#7009FF": "field",
182
- "#08FFD6": "armchair",
183
- "#07FFE0": "seat",
184
- "#FFB806": "fence;fencing",
185
- "#0AFF47": "desk",
186
- "#FF290A": "rock;stone",
187
- "#07FFFF": "wardrobe;closet;press",
188
- "#E0FF08": "lamp",
189
- "#6608FF": "bathtub;bathing;tub;bath;tub",
190
- "#FF3D06": "railing;rail",
191
- "#FFC207": "cushion",
192
- "#FF7A08": "base;pedestal;stand",
193
- "#00FF14": "box",
194
- "#FF0829": "column;pillar",
195
- "#FF0599": "signboard;sign",
196
- "#0633FF": "chest;of;drawers;chest;bureau;dresser",
197
- "#EB0CFF": "counter",
198
- "#A09614": "sand",
199
- "#00A3FF": "sink",
200
- "#8C8C8C": "skyscraper",
201
- "#FA0A0F": "fireplace;hearth;open;fireplace",
202
- "#14FF00": "refrigerator;icebox",
203
- "#1FFF00": "grandstand;covered;stand",
204
- "#FF1F00": "path",
205
- "#FFE000": "stairs;steps",
206
- "#99FF00": "runway",
207
- "#0000FF": "case;display;case;showcase;vitrine",
208
- "#FF4700": "pool;table;billiard;table;snooker;table",
209
- "#00EBFF": "pillow",
210
- "#00ADFF": "screen;door;screen",
211
- "#1F00FF": "stairway;staircase",
212
- "#0BC8C8": "river",
213
- "#FF5200": "bridge;span",
214
- "#00FFF5": "bookcase",
215
- "#003DFF": "blind;screen",
216
- "#00FF70": "coffee;table;cocktail;table",
217
- "#00FF85": "toilet;can;commode;crapper;pot;potty;stool;throne",
218
- "#FF0000": "flower",
219
- "#FFA300": "book",
220
- "#FF6600": "hill",
221
- "#C2FF00": "bench",
222
- "#008FFF": "countertop",
223
- "#33FF00": "stove;kitchen;stove;range;kitchen;range;cooking;stove",
224
- "#0052FF": "palm;palm;tree",
225
- "#00FF29": "kitchen;island",
226
- "#00FFAD": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
227
- "#0A00FF": "swivel;chair",
228
- "#ADFF00": "boat",
229
- "#00FF99": "bar",
230
- "#FF5C00": "arcade;machine",
231
- "#FF00FF": "hovel;hut;hutch;shack;shanty",
232
- "#FF00F5": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle",
233
- "#FF0066": "towel",
234
- "#FFAD00": "light;light;source",
235
- "#FF0014": "truck;motortruck",
236
- "#FFB8B8": "tower",
237
- "#001FFF": "chandelier;pendant;pendent",
238
- "#00FF3D": "awning;sunshade;sunblind",
239
- "#0047FF": "streetlight;street;lamp",
240
- "#FF00CC": "booth;cubicle;stall;kiosk",
241
- "#00FFC2": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
242
- "#00FF52": "airplane;aeroplane;plane",
243
- "#000AFF": "dirt;track",
244
- "#0070FF": "apparel;wearing;apparel;dress;clothes",
245
- "#3300FF": "pole",
246
- "#00C2FF": "land;ground;soil",
247
- "#007AFF": "bannister;banister;balustrade;balusters;handrail",
248
- "#00FFA3": "escalator;moving;staircase;moving;stairway",
249
- "#FF9900": "ottoman;pouf;pouffe;puff;hassock",
250
- "#00FF0A": "bottle",
251
- "#FF7000": "buffet;counter;sideboard",
252
- "#8FFF00": "poster;posting;placard;notice;bill;card",
253
- "#5200FF": "stage",
254
- "#A3FF00": "van",
255
- "#FFEB00": "ship",
256
- "#08B8AA": "fountain",
257
- "#8500FF": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter",
258
- "#00FF5C": "canopy",
259
- "#B800FF": "washer;automatic;washer;washing;machine",
260
- "#FF001F": "plaything;toy",
261
- "#00B8FF": "swimming;pool;swimming;bath;natatorium",
262
- "#00D6FF": "stool",
263
- "#FF0070": "barrel;cask",
264
- "#5CFF00": "basket;handbasket",
265
- "#00E0FF": "waterfall;falls",
266
- "#70E0FF": "tent;collapsible;shelter",
267
- "#46B8A0": "bag",
268
- "#A300FF": "minibike;motorbike",
269
- "#9900FF": "cradle",
270
- "#47FF00": "oven",
271
- "#FF00A3": "ball",
272
- "#FFCC00": "food;solid;food",
273
- "#FF008F": "step;stair",
274
- "#00FFEB": "tank;storage;tank",
275
- "#85FF00": "trade;name;brand;name;brand;marque",
276
- "#FF00EB": "microwave;microwave;oven",
277
- "#F500FF": "pot;flowerpot",
278
- "#FF007A": "animal;animate;being;beast;brute;creature;fauna",
279
- "#FFF500": "bicycle;bike;wheel;cycle",
280
- "#0ABED4": "lake",
281
- "#D6FF00": "dishwasher;dish;washer;dishwashing;machine",
282
- "#00CCFF": "screen;silver;screen;projection;screen",
283
- "#1400FF": "blanket;cover",
284
- "#FFFF00": "sculpture",
285
- "#0099FF": "hood;exhaust;hood",
286
- "#0029FF": "sconce",
287
- "#00FFCC": "vase",
288
- "#2900FF": "traffic;light;traffic;signal;stoplight",
289
- "#29FF00": "tray",
290
- "#AD00FF": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin",
291
- "#00F5FF": "fan",
292
- "#4700FF": "pier;wharf;wharfage;dock",
293
- "#7A00FF": "crt;screen",
294
- "#00FFB8": "plate",
295
- "#005CFF": "monitor;monitoring;device",
296
- "#B8FF00": "bulletin;board;notice;board",
297
- "#0085FF": "shower",
298
- "#FFD600": "radiator",
299
- "#19C2C2": "glass;drinking;glass",
300
- "#66FF00": "clock",
301
- "#5C00FF": "flag",
302
- }
303
-
304
- def ade_palette() -> List[List[int]]:
305
- """ADE20K palette that maps each class to RGB values."""
306
- return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
307
- [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
308
- [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
309
- [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
310
- [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
311
- [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
312
- [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
313
- [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
314
- [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
315
- [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
316
- [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
317
- [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
318
- [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
319
- [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
320
- [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
321
- [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
322
- [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
323
- [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
324
- [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
325
- [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
326
- [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
327
- [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
328
- [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
329
- [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
330
- [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
331
- [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
332
- [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
333
- [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
334
- [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
335
- [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
336
- [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
337
- [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
338
- [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
339
- [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
340
- [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
341
- [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
342
- [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
343
- [102, 255, 0], [92, 0, 255]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.py DELETED
@@ -1,43 +0,0 @@
1
- """File with configs"""
2
- from palette import COLOR_MAPPING_, COLOR_MAPPING
3
-
4
- # HEIGHT = 1024
5
- # WIDTH = 1024
6
- # HEIGHT = 512
7
- # WIDTH = 512
8
- #
9
-
10
- # def setResoluton(res):
11
- # global HEIGHT, WIDTH
12
- # HEIGHT = res
13
- # WIDTH = res
14
-
15
- def to_rgb(color: str) -> tuple:
16
- """Convert hex color to rgb.
17
- Args:
18
- color (str): hex color
19
- Returns:
20
- tuple: rgb color
21
- """
22
- return tuple(int(color[i:i+2], 16) for i in (1, 3, 5))
23
-
24
- COLOR_NAMES = list(COLOR_MAPPING.keys())
25
- COLOR_RGB = [to_rgb(k) for k in COLOR_MAPPING_.keys()] + [(0, 0, 0), (255, 255, 255)]
26
- INVERSE_COLORS = {v: to_rgb(k) for k, v in COLOR_MAPPING_.items()}
27
- COLOR_MAPPING_RGB = {to_rgb(k): v for k, v in COLOR_MAPPING_.items()}
28
-
29
- def map_colors(color: str) -> str:
30
- """Map color to hex value.
31
- Args:
32
- color (str): color name
33
- Returns:
34
- str: hex value
35
- """
36
- return COLOR_MAPPING[color]
37
-
38
- def map_colors_rgb(color: tuple) -> str:
39
- return COLOR_MAPPING_RGB[color]
40
-
41
-
42
- POS_PROMPT = "tree, sky, cloud, scenery, outdoors, grass, flowers, sunlight, beautiful, ultra detailed beautiful landscape, architectural renderings vegetation, high res, best high quality landscape, outdoor lighting, sunshine, 4k, 8k, realistic"
43
- NEG_PROMPT= "lowres, deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, mutated hands and fingers, out of frame"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
empty_room.jpg DELETED
Binary file (62.4 kB)
 
helpers.py CHANGED
@@ -9,7 +9,6 @@ def flush():
9
  torch.cuda.empty_cache()
10
 
11
 
12
-
13
  def convolution(mask: Image.Image, size=9) -> Image:
14
  """Method to blur the mask
15
  Args:
 
9
  torch.cuda.empty_cache()
10
 
11
 
 
12
  def convolution(mask: Image.Image, size=9) -> Image:
13
  """Method to blur the mask
14
  Args:
models.py CHANGED
@@ -1,66 +1,17 @@
1
- """This file contains methods for inference and image generation."""
2
- import logging
3
- from typing import List, Tuple, Dict
4
 
 
 
5
 
6
  import torch
7
  import numpy as np
8
  from PIL import Image
9
 
10
- from diffusers import ControlNetModel, UniPCMultistepScheduler
11
-
12
- from palette import ade_palette
13
- from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
14
  from helpers import flush, postprocess_image_masking, convolution
15
- from pipelines import ControlNetPipeline, SDPipeline, get_inpainting_pipeline, get_controlnet
16
 
17
  LOGGING = logging.getLogger(__name__)
18
 
19
 
20
- @torch.inference_mode()
21
- def make_image_controlnet(image: np.ndarray,
22
- mask_image: np.ndarray,
23
- controlnet_conditioning_image: np.ndarray,
24
- positive_prompt: str, negative_prompt: str,
25
- seed: int = 2356132) -> List[Image.Image]:
26
- """Method to make image using controlnet
27
- Args:
28
- image (np.ndarray): input image
29
- mask_image (np.ndarray): mask image
30
- controlnet_conditioning_image (np.ndarray): conditioning image
31
- positive_prompt (str): positive prompt string
32
- negative_prompt (str): negative prompt string
33
- seed (int, optional): seed. Defaults to 2356132.
34
- Returns:
35
- List[Image.Image]: list of generated images
36
- """
37
-
38
- pipe = get_controlnet()
39
- flush()
40
-
41
- image = Image.fromarray(image).convert("RGB")
42
- controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert("RGB")#.filter(ImageFilter.GaussianBlur(radius = 9))
43
- mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
44
- mask_image_postproc = convolution(mask_image)
45
-
46
-
47
-
48
- generated_image = pipe(
49
- prompt=positive_prompt,
50
- negative_prompt=negative_prompt,
51
- num_inference_steps=50,
52
- strength=1.00,
53
- guidance_scale=7.0,
54
- generator=[torch.Generator(device="cuda").manual_seed(seed)],
55
- image=image,
56
- mask_image=mask_image,
57
- controlnet_conditioning_image=controlnet_conditioning_image,
58
- ).images[0]
59
- generated_image = postprocess_image_masking(generated_image, image, mask_image_postproc)
60
-
61
- return generated_image
62
-
63
-
64
  @torch.inference_mode()
65
  def make_inpainting(positive_prompt: str,
66
  image: Image,
 
 
 
 
1
 
2
+ import logging
3
+ from typing import List
4
 
5
  import torch
6
  import numpy as np
7
  from PIL import Image
8
 
 
 
 
 
9
  from helpers import flush, postprocess_image_masking, convolution
10
+ from pipelines import get_inpainting_pipeline
11
 
12
  LOGGING = logging.getLogger(__name__)
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  @torch.inference_mode()
16
  def make_inpainting(positive_prompt: str,
17
  image: Image,
palette.py DELETED
@@ -1,38 +0,0 @@
1
- """This file contains color information"""
2
- from typing import List, Dict
3
- from colors import COLOR_MAPPING_, COLOR_MAPPING_CATEGORY_, ade_palette
4
-
5
-
6
- def convert_hex_to_rgba(hex_code: str) -> str:
7
- """Convert hex code to rgba.
8
- Args:
9
- hex_code (str): hex string
10
- Returns:
11
- str: rgba string
12
- """
13
- hex_code = hex_code.lstrip('#')
14
- return "rgba(" + str(int(hex_code[0:2], 16)) + ", " + str(int(hex_code[2:4], 16)) + ", " + str(int(hex_code[4:6], 16)) + ", 1.0)"
15
-
16
-
17
- def convert_dict_to_rgba(color_dict: Dict) -> Dict:
18
- """Convert hex code to rgba for all elements in a dictionary.
19
- Args:
20
- color_dict (Dict): color dictionary
21
- Returns:
22
- Dict: color dictionary with rgba values
23
- """
24
- updated_dict = {}
25
- for k, v in color_dict.items():
26
- updated_dict[convert_hex_to_rgba(k)] = v
27
- return updated_dict
28
-
29
-
30
- def convert_nested_dict_to_rgba(nested_dict):
31
- updated_dict = {}
32
- for k, v in nested_dict.items():
33
- updated_dict[k] = convert_dict_to_rgba(v)
34
- return updated_dict
35
-
36
-
37
- COLOR_MAPPING = convert_dict_to_rgba(COLOR_MAPPING_)
38
- COLOR_MAPPING_CATEGORY = convert_nested_dict_to_rgba(COLOR_MAPPING_CATEGORY_)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipelines.py CHANGED
@@ -1,69 +1,12 @@
1
  import logging
2
- from typing import List, Tuple, Dict
3
-
4
-
5
  import torch
6
- import gc
7
  import time
8
- import numpy as np
9
- from PIL import Image
10
- from time import perf_counter
11
- from contextlib import contextmanager
12
- from scipy.signal import fftconvolve
13
- from PIL import ImageFilter
14
-
15
- from diffusers import ControlNetModel, UniPCMultistepScheduler
16
  from diffusers import StableDiffusionInpaintPipeline
17
 
18
  # from config import WIDTH, HEIGHT
19
- from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
20
  from helpers import flush
21
 
22
  LOGGING = logging.getLogger(__name__)
23
-
24
- class ControlNetPipeline:
25
- def __init__(self):
26
- self.in_use = False
27
- self.controlnet = ControlNetModel.from_pretrained(
28
- "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
29
-
30
- self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
31
- "runwayml/stable-diffusion-inpainting",
32
- controlnet=self.controlnet,
33
- safety_checker=None,
34
- torch_dtype=torch.float16
35
- )
36
-
37
- self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
38
- self.pipe.enable_xformers_memory_efficient_attention()
39
- self.pipe = self.pipe.to("cuda")
40
-
41
- self.waiting_queue = []
42
- self.count = 0
43
-
44
- @property
45
- def queue_size(self):
46
- return len(self.waiting_queue)
47
-
48
- def __call__(self, **kwargs):
49
- self.count += 1
50
- number = self.count
51
-
52
- self.waiting_queue.append(number)
53
-
54
- # wait until the next number in the queue is the current number
55
- while self.waiting_queue[0] != number:
56
- print(f"Wait for your turn {number} in queue {self.waiting_queue}")
57
- time.sleep(0.5)
58
- pass
59
-
60
- # it's your turn, so remove the number from the queue
61
- # and call the function
62
- print("It's the turn of", self.count)
63
- results = self.pipe(**kwargs)
64
- self.waiting_queue.pop(0)
65
- flush()
66
- return results
67
 
68
  class SDPipeline:
69
  def __init__(self):
@@ -104,18 +47,6 @@ class SDPipeline:
104
  return results
105
 
106
 
107
-
108
-
109
- def get_controlnet():
110
- """Method to load the controlnet model
111
- Returns:
112
- ControlNetModel: controlnet model
113
- """
114
- pipe = ControlNetPipeline()
115
- return pipe
116
-
117
-
118
-
119
  def get_inpainting_pipeline():
120
  """Method to load the inpainting pipeline
121
  Returns:
 
1
  import logging
 
 
 
2
  import torch
 
3
  import time
 
 
 
 
 
 
 
 
4
  from diffusers import StableDiffusionInpaintPipeline
5
 
6
  # from config import WIDTH, HEIGHT
 
7
  from helpers import flush
8
 
9
  LOGGING = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class SDPipeline:
12
  def __init__(self):
 
47
  return results
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def get_inpainting_pipeline():
51
  """Method to load the inpainting pipeline
52
  Returns:
preprocessing.py DELETED
@@ -1,55 +0,0 @@
1
- """Preprocessing methods"""
2
- import logging
3
- from typing import List, Tuple
4
-
5
- import numpy as np
6
- from PIL import Image, ImageFilter
7
-
8
-
9
- from config import COLOR_RGB
10
- # from enhance_config import ENHANCE_SETTINGS
11
-
12
- LOGGING = logging.getLogger(__name__)
13
-
14
-
15
- def preprocess_seg_mask(canvas_seg, real_seg: Image.Image = None) -> Tuple[np.ndarray, np.ndarray]:
16
- """Preprocess the segmentation mask.
17
- Args:
18
- canvas_seg: segmentation canvas
19
- real_seg (Image.Image, optional): segmentation mask. Defaults to None.
20
- Returns:
21
- Tuple[np.ndarray, np.ndarray]: segmentation mask, segmentation mask with overlay
22
- """
23
- # get unique colors in the segmentation
24
- image_seg = canvas_seg.image_data.copy()[:, :, :3]
25
-
26
- # average the colors of the segmentation masks
27
- average_color = np.mean(image_seg, axis=(2))
28
- mask = average_color[:, :] > 0
29
- if mask.sum() > 0:
30
- mask = mask * 1
31
-
32
- unique_colors = np.unique(image_seg.reshape(-1, image_seg.shape[-1]), axis=0)
33
- unique_colors = [tuple(color) for color in unique_colors]
34
-
35
- unique_colors = [color for color in unique_colors if np.sum(
36
- np.all(image_seg == color, axis=-1)) > 100]
37
-
38
- unique_colors_exact = [color for color in unique_colors if color in COLOR_RGB]
39
-
40
- if real_seg is not None:
41
- overlay_seg = np.array(real_seg)
42
-
43
- unique_colors = np.unique(overlay_seg.reshape(-1, overlay_seg.shape[-1]), axis=0)
44
- unique_colors = [tuple(color) for color in unique_colors]
45
-
46
- for color in unique_colors_exact:
47
- if color != (255, 255, 255) and color != (0, 0, 0):
48
- overlay_seg[np.all(image_seg == color, axis=-1)] = color
49
- image_seg = overlay_seg
50
-
51
- return mask, image_seg
52
-
53
-
54
-
55
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,5 +1,3 @@
1
- streamlit==1.20.0
2
- streamlit-drawable-canvas==0.9.0
3
  diffusers==0.15.0
4
  xformers==0.0.16
5
  transformers==4.28.0
@@ -8,7 +6,6 @@ git+https://github.com/huggingface/accelerate.git
8
  opencv-python-headless==4.7.0.72
9
  scipy==1.10.0
10
  python-docx
11
- extra-streamlit-components==0.1.56
12
  triton
13
  altair<5
14
  gradio
 
 
 
1
  diffusers==0.15.0
2
  xformers==0.0.16
3
  transformers==4.28.0
 
6
  opencv-python-headless==4.7.0.72
7
  scipy==1.10.0
8
  python-docx
 
9
  triton
10
  altair<5
11
  gradio
segmentation.py DELETED
@@ -1,54 +0,0 @@
1
- import logging
2
- from typing import List, Tuple, Dict
3
-
4
- import torch
5
- import gc
6
- import numpy as np
7
- from PIL import Image
8
-
9
- from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
10
-
11
- from palette import ade_palette
12
-
13
- LOGGING = logging.getLogger(__name__)
14
-
15
-
16
- def flush():
17
- gc.collect()
18
- torch.cuda.empty_cache()
19
-
20
-
21
- def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
22
- """Method to load the segmentation pipeline
23
- Returns:
24
- Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
25
- """
26
- image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
27
- image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
28
- "openmmlab/upernet-convnext-small")
29
- return image_processor, image_segmentor
30
-
31
-
32
- @torch.inference_mode()
33
- @torch.autocast('cuda')
34
- def segment_image(image: Image) -> Image:
35
- """Method to segment image
36
- Args:
37
- image (Image): input image
38
- Returns:
39
- Image: segmented image
40
- """
41
- image_processor, image_segmentor = get_segmentation_pipeline()
42
- pixel_values = image_processor(image, return_tensors="pt").pixel_values
43
- with torch.no_grad():
44
- outputs = image_segmentor(pixel_values)
45
-
46
- seg = image_processor.post_process_semantic_segmentation(
47
- outputs, target_sizes=[image.size[::-1]])[0]
48
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
49
- palette = np.array(ade_palette())
50
- for label, color in enumerate(palette):
51
- color_seg[seg == label, :] = color
52
- color_seg = color_seg.astype(np.uint8)
53
- seg_image = Image.fromarray(color_seg).convert('RGB')
54
- return seg_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
stable_diffusion_controlnet_inpaint_img2img.py DELETED
@@ -1,1112 +0,0 @@
1
- """This file contains the StableDiffusionControlNetInpaintImg2ImgPipeline class from the
2
- community pipelines from the diffusers library of HuggingFace.
3
- """
4
- # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
5
-
6
- import inspect
7
- from typing import Any, Callable, Dict, List, Optional, Union
8
-
9
- import numpy as np
10
- import PIL.Image
11
- import torch
12
- import torch.nn.functional as F
13
- from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
14
-
15
- from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
16
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
17
- from diffusers.schedulers import KarrasDiffusionSchedulers
18
- from diffusers.utils import (
19
- PIL_INTERPOLATION,
20
- is_accelerate_available,
21
- is_accelerate_version,
22
- randn_tensor,
23
- replace_example_docstring,
24
- )
25
-
26
-
27
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
-
29
- EXAMPLE_DOC_STRING = """
30
- Examples:
31
- ```py
32
- >>> import numpy as np
33
- >>> import torch
34
- >>> from PIL import Image
35
- >>> from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
36
- >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
37
- >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
38
- >>> from diffusers.utils import load_image
39
- >>> def ade_palette():
40
- return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
41
- [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
42
- [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
43
- [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
44
- [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
45
- [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
46
- [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
47
- [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
48
- [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
49
- [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
50
- [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
51
- [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
52
- [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
53
- [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
54
- [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
55
- [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
56
- [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
57
- [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
58
- [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
59
- [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
60
- [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
61
- [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
62
- [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
63
- [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
64
- [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
65
- [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
66
- [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
67
- [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
68
- [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
69
- [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
70
- [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
71
- [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
72
- [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
73
- [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
74
- [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
75
- [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
76
- [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
77
- [102, 255, 0], [92, 0, 255]]
78
- >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
79
- >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
80
- >>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
81
- "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
82
- )
83
- >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
84
- >>> pipe.enable_xformers_memory_efficient_attention()
85
- >>> pipe.enable_model_cpu_offload()
86
- >>> def image_to_seg(image):
87
- pixel_values = image_processor(image, return_tensors="pt").pixel_values
88
- with torch.no_grad():
89
- outputs = image_segmentor(pixel_values)
90
- seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
91
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
92
- palette = np.array(ade_palette())
93
- for label, color in enumerate(palette):
94
- color_seg[seg == label, :] = color
95
- color_seg = color_seg.astype(np.uint8)
96
- seg_image = Image.fromarray(color_seg)
97
- return seg_image
98
- >>> image = load_image(
99
- "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
100
- )
101
- >>> mask_image = load_image(
102
- "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
103
- )
104
- >>> controlnet_conditioning_image = image_to_seg(image)
105
- >>> image = pipe(
106
- "Face of a yellow cat, high resolution, sitting on a park bench",
107
- image,
108
- mask_image,
109
- controlnet_conditioning_image,
110
- num_inference_steps=20,
111
- ).images[0]
112
- >>> image.save("out.png")
113
- ```
114
- """
115
-
116
-
117
- def prepare_image(image):
118
- if isinstance(image, torch.Tensor):
119
- # Batch single image
120
- if image.ndim == 3:
121
- image = image.unsqueeze(0)
122
-
123
- image = image.to(dtype=torch.float32)
124
- else:
125
- # preprocess image
126
- if isinstance(image, (PIL.Image.Image, np.ndarray)):
127
- image = [image]
128
-
129
- if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
130
- image = [np.array(i.convert("RGB"))[None, :] for i in image]
131
- image = np.concatenate(image, axis=0)
132
- elif isinstance(image, list) and isinstance(image[0], np.ndarray):
133
- image = np.concatenate([i[None, :] for i in image], axis=0)
134
-
135
- image = image.transpose(0, 3, 1, 2)
136
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
137
-
138
- return image
139
-
140
-
141
- def prepare_mask_image(mask_image):
142
- if isinstance(mask_image, torch.Tensor):
143
- if mask_image.ndim == 2:
144
- # Batch and add channel dim for single mask
145
- mask_image = mask_image.unsqueeze(0).unsqueeze(0)
146
- elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
147
- # Single mask, the 0'th dimension is considered to be
148
- # the existing batch size of 1
149
- mask_image = mask_image.unsqueeze(0)
150
- elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
151
- # Batch of mask, the 0'th dimension is considered to be
152
- # the batching dimension
153
- mask_image = mask_image.unsqueeze(1)
154
-
155
- # Binarize mask
156
- mask_image[mask_image < 0.5] = 0
157
- mask_image[mask_image >= 0.5] = 1
158
- else:
159
- # preprocess mask
160
- if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
161
- mask_image = [mask_image]
162
-
163
- if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
164
- mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
165
- mask_image = mask_image.astype(np.float32) / 255.0
166
- elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
167
- mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
168
-
169
- mask_image[mask_image < 0.5] = 0
170
- mask_image[mask_image >= 0.5] = 1
171
- mask_image = torch.from_numpy(mask_image)
172
-
173
- return mask_image
174
-
175
-
176
- def prepare_controlnet_conditioning_image(
177
- controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
178
- ):
179
- if not isinstance(controlnet_conditioning_image, torch.Tensor):
180
- if isinstance(controlnet_conditioning_image, PIL.Image.Image):
181
- controlnet_conditioning_image = [controlnet_conditioning_image]
182
-
183
- if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
184
- controlnet_conditioning_image = [
185
- np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
186
- for i in controlnet_conditioning_image
187
- ]
188
- controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
189
- controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
190
- controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
191
- controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
192
- elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
193
- controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
194
-
195
- image_batch_size = controlnet_conditioning_image.shape[0]
196
-
197
- if image_batch_size == 1:
198
- repeat_by = batch_size
199
- else:
200
- # image batch size is the same as prompt batch size
201
- repeat_by = num_images_per_prompt
202
-
203
- controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
204
-
205
- controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
206
-
207
- return controlnet_conditioning_image
208
-
209
-
210
- class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
211
- """
212
- Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
213
- """
214
-
215
- _optional_components = ["safety_checker", "feature_extractor"]
216
-
217
- def __init__(
218
- self,
219
- vae: AutoencoderKL,
220
- text_encoder: CLIPTextModel,
221
- tokenizer: CLIPTokenizer,
222
- unet: UNet2DConditionModel,
223
- controlnet: ControlNetModel,
224
- scheduler: KarrasDiffusionSchedulers,
225
- safety_checker: StableDiffusionSafetyChecker,
226
- feature_extractor: CLIPFeatureExtractor,
227
- requires_safety_checker: bool = True,
228
- ):
229
- super().__init__()
230
-
231
- if safety_checker is None and requires_safety_checker:
232
- logger.warning(
233
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
234
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
235
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
236
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
237
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
238
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
239
- )
240
-
241
- if safety_checker is not None and feature_extractor is None:
242
- raise ValueError(
243
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
244
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
245
- )
246
-
247
- self.register_modules(
248
- vae=vae,
249
- text_encoder=text_encoder,
250
- tokenizer=tokenizer,
251
- unet=unet,
252
- controlnet=controlnet,
253
- scheduler=scheduler,
254
- safety_checker=safety_checker,
255
- feature_extractor=feature_extractor,
256
- )
257
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
258
- self.register_to_config(requires_safety_checker=requires_safety_checker)
259
-
260
- def enable_vae_slicing(self):
261
- r"""
262
- Enable sliced VAE decoding.
263
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
264
- steps. This is useful to save some memory and allow larger batch sizes.
265
- """
266
- self.vae.enable_slicing()
267
-
268
- def disable_vae_slicing(self):
269
- r"""
270
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
271
- computing decoding in one step.
272
- """
273
- self.vae.disable_slicing()
274
-
275
- def enable_sequential_cpu_offload(self, gpu_id=0):
276
- r"""
277
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
278
- text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
279
- `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
280
- Note that offloading happens on a submodule basis. Memory savings are higher than with
281
- `enable_model_cpu_offload`, but performance is lower.
282
- """
283
- if is_accelerate_available():
284
- from accelerate import cpu_offload
285
- else:
286
- raise ImportError("Please install accelerate via `pip install accelerate`")
287
-
288
- device = torch.device(f"cuda:{gpu_id}")
289
-
290
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
291
- cpu_offload(cpu_offloaded_model, device)
292
-
293
- if self.safety_checker is not None:
294
- cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
295
-
296
- def enable_model_cpu_offload(self, gpu_id=0):
297
- r"""
298
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
299
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
300
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
301
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
302
- """
303
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
304
- from accelerate import cpu_offload_with_hook
305
- else:
306
- raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
307
-
308
- device = torch.device(f"cuda:{gpu_id}")
309
-
310
- hook = None
311
- for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
312
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
313
-
314
- if self.safety_checker is not None:
315
- # the safety checker can offload the vae again
316
- _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
317
-
318
- # control net hook has be manually offloaded as it alternates with unet
319
- cpu_offload_with_hook(self.controlnet, device)
320
-
321
- # We'll offload the last model manually.
322
- self.final_offload_hook = hook
323
-
324
- @property
325
- def _execution_device(self):
326
- r"""
327
- Returns the device on which the pipeline's models will be executed. After calling
328
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
329
- hooks.
330
- """
331
- if not hasattr(self.unet, "_hf_hook"):
332
- return self.device
333
- for module in self.unet.modules():
334
- if (
335
- hasattr(module, "_hf_hook")
336
- and hasattr(module._hf_hook, "execution_device")
337
- and module._hf_hook.execution_device is not None
338
- ):
339
- return torch.device(module._hf_hook.execution_device)
340
- return self.device
341
-
342
- def _encode_prompt(
343
- self,
344
- prompt,
345
- device,
346
- num_images_per_prompt,
347
- do_classifier_free_guidance,
348
- negative_prompt=None,
349
- prompt_embeds: Optional[torch.FloatTensor] = None,
350
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
351
- ):
352
- r"""
353
- Encodes the prompt into text encoder hidden states.
354
- Args:
355
- prompt (`str` or `List[str]`, *optional*):
356
- prompt to be encoded
357
- device: (`torch.device`):
358
- torch device
359
- num_images_per_prompt (`int`):
360
- number of images that should be generated per prompt
361
- do_classifier_free_guidance (`bool`):
362
- whether to use classifier free guidance or not
363
- negative_prompt (`str` or `List[str]`, *optional*):
364
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
365
- `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
366
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
367
- prompt_embeds (`torch.FloatTensor`, *optional*):
368
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
369
- provided, text embeddings will be generated from `prompt` input argument.
370
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
371
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
372
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
373
- argument.
374
- """
375
- if prompt is not None and isinstance(prompt, str):
376
- batch_size = 1
377
- elif prompt is not None and isinstance(prompt, list):
378
- batch_size = len(prompt)
379
- else:
380
- batch_size = prompt_embeds.shape[0]
381
-
382
- if prompt_embeds is None:
383
- text_inputs = self.tokenizer(
384
- prompt,
385
- padding="max_length",
386
- max_length=self.tokenizer.model_max_length,
387
- truncation=True,
388
- return_tensors="pt",
389
- )
390
- text_input_ids = text_inputs.input_ids
391
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
392
-
393
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
394
- text_input_ids, untruncated_ids
395
- ):
396
- removed_text = self.tokenizer.batch_decode(
397
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
398
- )
399
- logger.warning(
400
- "The following part of your input was truncated because CLIP can only handle sequences up to"
401
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
402
- )
403
-
404
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
405
- attention_mask = text_inputs.attention_mask.to(device)
406
- else:
407
- attention_mask = None
408
-
409
- prompt_embeds = self.text_encoder(
410
- text_input_ids.to(device),
411
- attention_mask=attention_mask,
412
- )
413
- prompt_embeds = prompt_embeds[0]
414
-
415
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
416
-
417
- bs_embed, seq_len, _ = prompt_embeds.shape
418
- # duplicate text embeddings for each generation per prompt, using mps friendly method
419
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
420
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
421
-
422
- # get unconditional embeddings for classifier free guidance
423
- if do_classifier_free_guidance and negative_prompt_embeds is None:
424
- uncond_tokens: List[str]
425
- if negative_prompt is None:
426
- uncond_tokens = [""] * batch_size
427
- elif type(prompt) is not type(negative_prompt):
428
- raise TypeError(
429
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
430
- f" {type(prompt)}."
431
- )
432
- elif isinstance(negative_prompt, str):
433
- uncond_tokens = [negative_prompt]
434
- elif batch_size != len(negative_prompt):
435
- raise ValueError(
436
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
437
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
438
- " the batch size of `prompt`."
439
- )
440
- else:
441
- uncond_tokens = negative_prompt
442
-
443
- max_length = prompt_embeds.shape[1]
444
- uncond_input = self.tokenizer(
445
- uncond_tokens,
446
- padding="max_length",
447
- max_length=max_length,
448
- truncation=True,
449
- return_tensors="pt",
450
- )
451
-
452
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
453
- attention_mask = uncond_input.attention_mask.to(device)
454
- else:
455
- attention_mask = None
456
-
457
- negative_prompt_embeds = self.text_encoder(
458
- uncond_input.input_ids.to(device),
459
- attention_mask=attention_mask,
460
- )
461
- negative_prompt_embeds = negative_prompt_embeds[0]
462
-
463
- if do_classifier_free_guidance:
464
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
465
- seq_len = negative_prompt_embeds.shape[1]
466
-
467
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
468
-
469
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
470
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
471
-
472
- # For classifier free guidance, we need to do two forward passes.
473
- # Here we concatenate the unconditional and text embeddings into a single batch
474
- # to avoid doing two forward passes
475
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
476
-
477
- return prompt_embeds
478
-
479
- def run_safety_checker(self, image, device, dtype):
480
- if self.safety_checker is not None:
481
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
482
- image, has_nsfw_concept = self.safety_checker(
483
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
484
- )
485
- else:
486
- has_nsfw_concept = None
487
- return image, has_nsfw_concept
488
-
489
- def decode_latents(self, latents):
490
- latents = 1 / self.vae.config.scaling_factor * latents
491
- image = self.vae.decode(latents).sample
492
- image = (image / 2 + 0.5).clamp(0, 1)
493
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
494
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
495
- return image
496
-
497
- def prepare_extra_step_kwargs(self, generator, eta):
498
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
499
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
500
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
501
- # and should be between [0, 1]
502
-
503
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
504
- extra_step_kwargs = {}
505
- if accepts_eta:
506
- extra_step_kwargs["eta"] = eta
507
-
508
- # check if the scheduler accepts generator
509
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
510
- if accepts_generator:
511
- extra_step_kwargs["generator"] = generator
512
- return extra_step_kwargs
513
-
514
- def check_inputs(
515
- self,
516
- prompt,
517
- image,
518
- mask_image,
519
- controlnet_conditioning_image,
520
- height,
521
- width,
522
- callback_steps,
523
- negative_prompt=None,
524
- prompt_embeds=None,
525
- negative_prompt_embeds=None,
526
- strength=None,
527
- ):
528
- if height % 8 != 0 or width % 8 != 0:
529
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
530
-
531
- if (callback_steps is None) or (
532
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
533
- ):
534
- raise ValueError(
535
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
536
- f" {type(callback_steps)}."
537
- )
538
-
539
- if prompt is not None and prompt_embeds is not None:
540
- raise ValueError(
541
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
542
- " only forward one of the two."
543
- )
544
- elif prompt is None and prompt_embeds is None:
545
- raise ValueError(
546
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
547
- )
548
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
549
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
550
-
551
- if negative_prompt is not None and negative_prompt_embeds is not None:
552
- raise ValueError(
553
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
554
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
555
- )
556
-
557
- if prompt_embeds is not None and negative_prompt_embeds is not None:
558
- if prompt_embeds.shape != negative_prompt_embeds.shape:
559
- raise ValueError(
560
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
561
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
562
- f" {negative_prompt_embeds.shape}."
563
- )
564
-
565
- controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
566
- controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
567
- controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
568
- controlnet_conditioning_image[0], PIL.Image.Image
569
- )
570
- controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
571
- controlnet_conditioning_image[0], torch.Tensor
572
- )
573
-
574
- if (
575
- not controlnet_cond_image_is_pil
576
- and not controlnet_cond_image_is_tensor
577
- and not controlnet_cond_image_is_pil_list
578
- and not controlnet_cond_image_is_tensor_list
579
- ):
580
- raise TypeError(
581
- "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
582
- )
583
-
584
- if controlnet_cond_image_is_pil:
585
- controlnet_cond_image_batch_size = 1
586
- elif controlnet_cond_image_is_tensor:
587
- controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
588
- elif controlnet_cond_image_is_pil_list:
589
- controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
590
- elif controlnet_cond_image_is_tensor_list:
591
- controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
592
-
593
- if prompt is not None and isinstance(prompt, str):
594
- prompt_batch_size = 1
595
- elif prompt is not None and isinstance(prompt, list):
596
- prompt_batch_size = len(prompt)
597
- elif prompt_embeds is not None:
598
- prompt_batch_size = prompt_embeds.shape[0]
599
-
600
- if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
601
- raise ValueError(
602
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
603
- )
604
-
605
- if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
606
- raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
607
-
608
- if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
609
- raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")
610
-
611
- if isinstance(image, torch.Tensor):
612
- if image.ndim != 3 and image.ndim != 4:
613
- raise ValueError("`image` must have 3 or 4 dimensions")
614
-
615
- if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
616
- raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
617
-
618
- if image.ndim == 3:
619
- image_batch_size = 1
620
- image_channels, image_height, image_width = image.shape
621
- elif image.ndim == 4:
622
- image_batch_size, image_channels, image_height, image_width = image.shape
623
-
624
- if mask_image.ndim == 2:
625
- mask_image_batch_size = 1
626
- mask_image_channels = 1
627
- mask_image_height, mask_image_width = mask_image.shape
628
- elif mask_image.ndim == 3:
629
- mask_image_channels = 1
630
- mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
631
- elif mask_image.ndim == 4:
632
- mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape
633
-
634
- if image_channels != 3:
635
- raise ValueError("`image` must have 3 channels")
636
-
637
- if mask_image_channels != 1:
638
- raise ValueError("`mask_image` must have 1 channel")
639
-
640
- if image_batch_size != mask_image_batch_size:
641
- raise ValueError("`image` and `mask_image` mush have the same batch sizes")
642
-
643
- if image_height != mask_image_height or image_width != mask_image_width:
644
- raise ValueError("`image` and `mask_image` must have the same height and width dimensions")
645
-
646
- if image.min() < -1 or image.max() > 1:
647
- raise ValueError("`image` should be in range [-1, 1]")
648
-
649
- if mask_image.min() < 0 or mask_image.max() > 1:
650
- raise ValueError("`mask_image` should be in range [0, 1]")
651
- else:
652
- mask_image_channels = 1
653
- image_channels = 3
654
-
655
- single_image_latent_channels = self.vae.config.latent_channels
656
-
657
- total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
658
-
659
- if total_latent_channels != self.unet.config.in_channels:
660
- raise ValueError(
661
- f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
662
- f" non inpainting latent channels: {single_image_latent_channels},"
663
- f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
664
- f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
665
- )
666
-
667
- if strength < 0 or strength > 1:
668
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
669
-
670
- def get_timesteps(self, num_inference_steps, strength, device):
671
- # get the original timestep using init_timestep
672
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
673
-
674
- t_start = max(num_inference_steps - init_timestep, 0)
675
- timesteps = self.scheduler.timesteps[t_start:]
676
-
677
- return timesteps, num_inference_steps - t_start
678
-
679
- def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
680
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
681
- raise ValueError(
682
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
683
- )
684
-
685
- image = image.to(device=device, dtype=dtype)
686
-
687
- batch_size = batch_size * num_images_per_prompt
688
- if isinstance(generator, list) and len(generator) != batch_size:
689
- raise ValueError(
690
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
691
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
692
- )
693
-
694
- if isinstance(generator, list):
695
- init_latents = [
696
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
697
- ]
698
- init_latents = torch.cat(init_latents, dim=0)
699
- else:
700
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
701
-
702
- init_latents = self.vae.config.scaling_factor * init_latents
703
-
704
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
705
- raise ValueError(
706
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
707
- )
708
- else:
709
- init_latents = torch.cat([init_latents], dim=0)
710
-
711
- shape = init_latents.shape
712
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
713
-
714
- # get latents
715
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
716
- latents = init_latents
717
-
718
- return latents
719
-
720
- def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
721
- # resize the mask to latents shape as we concatenate the mask to the latents
722
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
723
- # and half precision
724
- mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
725
- mask_image = mask_image.to(device=device, dtype=dtype)
726
-
727
- # duplicate mask for each generation per prompt, using mps friendly method
728
- if mask_image.shape[0] < batch_size:
729
- if not batch_size % mask_image.shape[0] == 0:
730
- raise ValueError(
731
- "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
732
- f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"
733
- " of masks that you pass is divisible by the total requested batch size."
734
- )
735
- mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
736
-
737
- mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image
738
-
739
- mask_image_latents = mask_image
740
-
741
- return mask_image_latents
742
-
743
- def prepare_masked_image_latents(
744
- self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
745
- ):
746
- masked_image = masked_image.to(device=device, dtype=dtype)
747
-
748
- # encode the mask image into latents space so we can concatenate it to the latents
749
- if isinstance(generator, list):
750
- masked_image_latents = [
751
- self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
752
- for i in range(batch_size)
753
- ]
754
- masked_image_latents = torch.cat(masked_image_latents, dim=0)
755
- else:
756
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
757
- masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
758
-
759
- # duplicate masked_image_latents for each generation per prompt, using mps friendly method
760
- if masked_image_latents.shape[0] < batch_size:
761
- if not batch_size % masked_image_latents.shape[0] == 0:
762
- raise ValueError(
763
- "The passed images and the required batch size don't match. Images are supposed to be duplicated"
764
- f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
765
- " Make sure the number of images that you pass is divisible by the total requested batch size."
766
- )
767
- masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
768
-
769
- masked_image_latents = (
770
- torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
771
- )
772
-
773
- # aligning device to prevent device errors when concating it with the latent model input
774
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
775
- return masked_image_latents
776
-
777
- def _default_height_width(self, height, width, image):
778
- if isinstance(image, list):
779
- image = image[0]
780
-
781
- if height is None:
782
- if isinstance(image, PIL.Image.Image):
783
- height = image.height
784
- elif isinstance(image, torch.Tensor):
785
- height = image.shape[3]
786
-
787
- height = (height // 8) * 8 # round down to nearest multiple of 8
788
-
789
- if width is None:
790
- if isinstance(image, PIL.Image.Image):
791
- width = image.width
792
- elif isinstance(image, torch.Tensor):
793
- width = image.shape[2]
794
-
795
- width = (width // 8) * 8 # round down to nearest multiple of 8
796
-
797
- return height, width
798
-
799
- @torch.no_grad()
800
- @replace_example_docstring(EXAMPLE_DOC_STRING)
801
- def __call__(
802
- self,
803
- prompt: Union[str, List[str]] = None,
804
- image: Union[torch.Tensor, PIL.Image.Image] = None,
805
- mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
806
- controlnet_conditioning_image: Union[
807
- torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
808
- ] = None,
809
- strength: float = 0.8,
810
- height: Optional[int] = None,
811
- width: Optional[int] = None,
812
- num_inference_steps: int = 50,
813
- guidance_scale: float = 7.5,
814
- negative_prompt: Optional[Union[str, List[str]]] = None,
815
- num_images_per_prompt: Optional[int] = 1,
816
- eta: float = 0.0,
817
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
818
- latents: Optional[torch.FloatTensor] = None,
819
- prompt_embeds: Optional[torch.FloatTensor] = None,
820
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
821
- output_type: Optional[str] = "pil",
822
- return_dict: bool = True,
823
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
824
- callback_steps: int = 1,
825
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
826
- controlnet_conditioning_scale: float = 1.0,
827
- controlnet_conditioning_scale_decay: float = 0.95,
828
- controlnet_steps: int = 10,
829
- ):
830
- r"""
831
- Function invoked when calling the pipeline for generation.
832
- Args:
833
- prompt (`str` or `List[str]`, *optional*):
834
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
835
- instead.
836
- image (`torch.Tensor` or `PIL.Image.Image`):
837
- `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
838
- be masked out with `mask_image` and repainted according to `prompt`.
839
- mask_image (`torch.Tensor` or `PIL.Image.Image`):
840
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
841
- repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
842
- to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
843
- instead of 3, so the expected shape would be `(B, H, W, 1)`.
844
- controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
845
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
846
- the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
847
- also be accepted as an image. The control image is automatically resized to fit the output image.
848
- strength (`float`, *optional*):
849
- Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
850
- will be used as a starting point, adding more noise to it the larger the `strength`. The number of
851
- denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
852
- be maximum and the denoising process will run for the full number of iterations specified in
853
- `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
854
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
855
- The height in pixels of the generated image.
856
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
857
- The width in pixels of the generated image.
858
- num_inference_steps (`int`, *optional*, defaults to 50):
859
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
860
- expense of slower inference.
861
- guidance_scale (`float`, *optional*, defaults to 7.5):
862
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
863
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
864
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
865
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
866
- usually at the expense of lower image quality.
867
- negative_prompt (`str` or `List[str]`, *optional*):
868
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
869
- `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
870
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
871
- num_images_per_prompt (`int`, *optional*, defaults to 1):
872
- The number of images to generate per prompt.
873
- eta (`float`, *optional*, defaults to 0.0):
874
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
875
- [`schedulers.DDIMScheduler`], will be ignored for others.
876
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
877
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
878
- to make generation deterministic.
879
- latents (`torch.FloatTensor`, *optional*):
880
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
881
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
882
- tensor will ge generated by sampling using the supplied random `generator`.
883
- prompt_embeds (`torch.FloatTensor`, *optional*):
884
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
885
- provided, text embeddings will be generated from `prompt` input argument.
886
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
887
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
888
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
889
- argument.
890
- output_type (`str`, *optional*, defaults to `"pil"`):
891
- The output format of the generate image. Choose between
892
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
893
- return_dict (`bool`, *optional*, defaults to `True`):
894
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
895
- plain tuple.
896
- callback (`Callable`, *optional*):
897
- A function that will be called every `callback_steps` steps during inference. The function will be
898
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
899
- callback_steps (`int`, *optional*, defaults to 1):
900
- The frequency at which the `callback` function will be called. If not specified, the callback will be
901
- called at every step.
902
- cross_attention_kwargs (`dict`, *optional*):
903
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
904
- `self.processor` in
905
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
906
- controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
907
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
908
- to the residual in the original unet.
909
- Examples:
910
- Returns:
911
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
912
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
913
- When returning a tuple, the first element is a list with the generated images, and the second element is a
914
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
915
- (nsfw) content, according to the `safety_checker`.
916
- """
917
- # 0. Default height and width to unet
918
- height, width = self._default_height_width(height, width, controlnet_conditioning_image)
919
-
920
- # 1. Check inputs. Raise error if not correct
921
- self.check_inputs(
922
- prompt,
923
- image,
924
- mask_image,
925
- controlnet_conditioning_image,
926
- height,
927
- width,
928
- callback_steps,
929
- negative_prompt,
930
- prompt_embeds,
931
- negative_prompt_embeds,
932
- strength,
933
- )
934
-
935
- # 2. Define call parameters
936
- if prompt is not None and isinstance(prompt, str):
937
- batch_size = 1
938
- elif prompt is not None and isinstance(prompt, list):
939
- batch_size = len(prompt)
940
- else:
941
- batch_size = prompt_embeds.shape[0]
942
-
943
- device = self._execution_device
944
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
945
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
946
- # corresponds to doing no classifier free guidance.
947
- do_classifier_free_guidance = guidance_scale > 1.0
948
-
949
- # 3. Encode input prompt
950
- prompt_embeds = self._encode_prompt(
951
- prompt,
952
- device,
953
- num_images_per_prompt,
954
- do_classifier_free_guidance,
955
- negative_prompt,
956
- prompt_embeds=prompt_embeds,
957
- negative_prompt_embeds=negative_prompt_embeds,
958
- )
959
-
960
- # 4. Prepare mask, image, and controlnet_conditioning_image
961
- image = prepare_image(image)
962
-
963
- mask_image = prepare_mask_image(mask_image)
964
-
965
- controlnet_conditioning_image = prepare_controlnet_conditioning_image(
966
- controlnet_conditioning_image,
967
- width,
968
- height,
969
- batch_size * num_images_per_prompt,
970
- num_images_per_prompt,
971
- device,
972
- self.controlnet.dtype,
973
- )
974
-
975
- masked_image = image * (mask_image < 0.5)
976
-
977
- # 5. Prepare timesteps
978
- self.scheduler.set_timesteps(num_inference_steps, device=device)
979
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
980
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
981
-
982
- # 6. Prepare latent variables
983
- latents = self.prepare_latents(
984
- image,
985
- latent_timestep,
986
- batch_size,
987
- num_images_per_prompt,
988
- prompt_embeds.dtype,
989
- device,
990
- generator,
991
- )
992
-
993
- mask_image_latents = self.prepare_mask_latents(
994
- mask_image,
995
- batch_size * num_images_per_prompt,
996
- height,
997
- width,
998
- prompt_embeds.dtype,
999
- device,
1000
- do_classifier_free_guidance,
1001
- )
1002
-
1003
- masked_image_latents = self.prepare_masked_image_latents(
1004
- masked_image,
1005
- batch_size * num_images_per_prompt,
1006
- height,
1007
- width,
1008
- prompt_embeds.dtype,
1009
- device,
1010
- generator,
1011
- do_classifier_free_guidance,
1012
- )
1013
-
1014
- if do_classifier_free_guidance:
1015
- controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
1016
-
1017
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1018
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1019
-
1020
- # 8. Denoising loop
1021
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1022
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1023
- for i, t in enumerate(timesteps):
1024
- # expand the latents if we are doing classifier free guidance
1025
- non_inpainting_latent_model_input = (
1026
- torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1027
- )
1028
-
1029
- non_inpainting_latent_model_input = self.scheduler.scale_model_input(
1030
- non_inpainting_latent_model_input, t
1031
- )
1032
-
1033
- inpainting_latent_model_input = torch.cat(
1034
- [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
1035
- )
1036
-
1037
- down_block_res_samples, mid_block_res_sample = self.controlnet(
1038
- non_inpainting_latent_model_input,
1039
- t,
1040
- encoder_hidden_states=prompt_embeds,
1041
- controlnet_cond=controlnet_conditioning_image,
1042
- return_dict=False,
1043
- )
1044
- if i <= controlnet_steps:
1045
- conditioning_scale = (controlnet_conditioning_scale * controlnet_conditioning_scale_decay ** i)
1046
- else:
1047
- conditioning_scale = 0.0
1048
-
1049
- down_block_res_samples = [
1050
- down_block_res_sample * conditioning_scale
1051
- for down_block_res_sample in down_block_res_samples
1052
- ]
1053
- mid_block_res_sample *= conditioning_scale
1054
-
1055
- # predict the noise residual
1056
- noise_pred = self.unet(
1057
- inpainting_latent_model_input,
1058
- t,
1059
- encoder_hidden_states=prompt_embeds,
1060
- cross_attention_kwargs=cross_attention_kwargs,
1061
- down_block_additional_residuals=down_block_res_samples,
1062
- mid_block_additional_residual=mid_block_res_sample,
1063
- ).sample
1064
-
1065
- # perform guidance
1066
- if do_classifier_free_guidance:
1067
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1068
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1069
-
1070
- # compute the previous noisy sample x_t -> x_t-1
1071
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1072
-
1073
- # call the callback, if provided
1074
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1075
- progress_bar.update()
1076
- if callback is not None and i % callback_steps == 0:
1077
- callback(i, t, latents)
1078
-
1079
- # If we do sequential model offloading, let's offload unet and controlnet
1080
- # manually for max memory savings
1081
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1082
- self.unet.to("cpu")
1083
- self.controlnet.to("cpu")
1084
- torch.cuda.empty_cache()
1085
-
1086
- if output_type == "latent":
1087
- image = latents
1088
- has_nsfw_concept = None
1089
- elif output_type == "pil":
1090
- # 8. Post-processing
1091
- image = self.decode_latents(latents)
1092
-
1093
- # 9. Run safety checker
1094
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1095
-
1096
- # 10. Convert to PIL
1097
- image = self.numpy_to_pil(image)
1098
- else:
1099
- # 8. Post-processing
1100
- image = self.decode_latents(latents)
1101
-
1102
- # 9. Run safety checker
1103
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1104
-
1105
- # Offload last model to CPU
1106
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1107
- self.final_offload_hook.offload()
1108
-
1109
- if not return_dict:
1110
- return (image, has_nsfw_concept)
1111
-
1112
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -1,6 +1,7 @@
1
 
2
  import numpy as np
3
  from PIL import Image
 
4
 
5
  def image_to_byte_array(image: Image) -> bytes:
6
  # BytesIO is a fake file stored in memory
 
1
 
2
  import numpy as np
3
  from PIL import Image
4
+ import io
5
 
6
  def image_to_byte_array(image: Image) -> bytes:
7
  # BytesIO is a fake file stored in memory