zxcgqq committed
Commit 7473a1e · 1 Parent(s): 5372321

Create visual_foundation_models.py

Files changed (1)
  1. visual_foundation_models.py +892 -0
visual_foundation_models.py ADDED
@@ -0,0 +1,892 @@
+from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline
+from diffusers import EulerAncestralDiscreteScheduler
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
+from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
+from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+
+import os
+import random
+import torch
+import cv2
+import uuid
+from PIL import Image, ImageOps
+import numpy as np
+from pytorch_lightning import seed_everything
+import math
+
+from langchain.llms.openai import OpenAI
+
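+# Note: @prompts (defined below) only attaches a tool name and description to each
+# decorated inference method. An agent layer (for example a LangChain-based controller,
+# assumed here and not defined in this file) is expected to read these attributes when
+# registering the methods as tools.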
+def prompts(name, description):
+    def decorator(func):
+        func.name = name
+        func.description = description
+        return func
+
+    return decorator
+
+def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100):
+    # Paste the original (ground-truth) image back into the centre of the outpainted
+    # image and feather a `steps`-pixel border with Gaussian weights so that the old
+    # and newly generated content blend smoothly.
+    new_size = new_image.size
+    old_size = old_image.size
+    easy_img = np.array(new_image)
+    gt_img_array = np.array(old_image)
+    pos_w = (new_size[0] - old_size[0]) // 2
+    pos_h = (new_size[1] - old_size[1]) // 2
+
+    kernel_h = cv2.getGaussianKernel(old_size[1], old_size[1] * sigma)
+    kernel_w = cv2.getGaussianKernel(old_size[0], old_size[0] * sigma)
+    kernel = np.multiply(kernel_h, np.transpose(kernel_w))
+
+    kernel[steps:-steps, steps:-steps] = 1
+    kernel[:steps, :steps] = kernel[:steps, :steps] / kernel[steps - 1, steps - 1]
+    kernel[:steps, -steps:] = kernel[:steps, -steps:] / kernel[steps - 1, -(steps)]
+    kernel[-steps:, :steps] = kernel[-steps:, :steps] / kernel[-steps, steps - 1]
+    kernel[-steps:, -steps:] = kernel[-steps:, -steps:] / kernel[-steps, -steps]
+    kernel = np.expand_dims(kernel, 2)
+    kernel = np.repeat(kernel, 3, 2)
+
+    weight = np.linspace(0, 1, steps)
+    top = np.expand_dims(weight, 1)
+    top = np.repeat(top, old_size[0] - 2 * steps, 1)
+    top = np.expand_dims(top, 2)
+    top = np.repeat(top, 3, 2)
+
+    weight = np.linspace(1, 0, steps)
+    down = np.expand_dims(weight, 1)
+    down = np.repeat(down, old_size[0] - 2 * steps, 1)
+    down = np.expand_dims(down, 2)
+    down = np.repeat(down, 3, 2)
+
+    weight = np.linspace(0, 1, steps)
+    left = np.expand_dims(weight, 0)
+    left = np.repeat(left, old_size[1] - 2 * steps, 0)
+    left = np.expand_dims(left, 2)
+    left = np.repeat(left, 3, 2)
+
+    weight = np.linspace(1, 0, steps)
+    right = np.expand_dims(weight, 0)
+    right = np.repeat(right, old_size[1] - 2 * steps, 0)
+    right = np.expand_dims(right, 2)
+    right = np.repeat(right, 3, 2)
+
+    kernel[:steps, steps:-steps] = top
+    kernel[-steps:, steps:-steps] = down
+    kernel[steps:-steps, :steps] = left
+    kernel[steps:-steps, -steps:] = right
+
+    pt_gt_img = easy_img[pos_h:pos_h + old_size[1], pos_w:pos_w + old_size[0]]
+    gaussian_gt_img = kernel * gt_img_array + (1 - kernel) * pt_gt_img  # gt img with blur img
+    gaussian_gt_img = gaussian_gt_img.astype(np.int64)
+    easy_img[pos_h:pos_h + old_size[1], pos_w:pos_w + old_size[0]] = gaussian_gt_img
+    gaussian_img = Image.fromarray(easy_img)
+    return gaussian_img
+
+def get_new_image_name(org_img_name, func_name="update"):
+    # Build a chained output filename of the form
+    # "<new_uuid>_<func_name>_<previous_uuid>_<original_name>.png" in the same directory,
+    # so every intermediate result stays traceable back to the original image.
+    head_tail = os.path.split(org_img_name)
+    head = head_tail[0]
+    tail = head_tail[1]
+    name_split = tail.split('.')[0].split('_')
+    this_new_uuid = str(uuid.uuid4())[0:4]
+    if len(name_split) == 1:
+        most_org_file_name = name_split[0]
+        recent_prev_file_name = name_split[0]
+        new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
+    else:
+        assert len(name_split) == 4
+        most_org_file_name = name_split[3]
+        recent_prev_file_name = name_split[0]
+        new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
+    return os.path.join(head, new_file_name)
+
+
+class MaskFormer:
+    def __init__(self, device):
+        print(f"Initializing MaskFormer to {device}")
+        self.device = device
+        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+        self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
+
+    def inference(self, image_path, text):
+        threshold = 0.5
+        min_area = 0.02
+        padding = 20
+        original_image = Image.open(image_path)
+        image = original_image.resize((512, 512))
+        inputs = self.processor(text=text, images=image, padding="max_length", return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
+        area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
+        if area_ratio < min_area:
+            return None
+        true_indices = np.argwhere(mask)
+        mask_array = np.zeros_like(mask, dtype=bool)
+        for idx in true_indices:
+            padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
+            mask_array[padded_slice] = True
+        visual_mask = (mask_array * 255).astype(np.uint8)
+        image_mask = Image.fromarray(visual_mask)
+        return image_mask.resize(original_image.size)
+
+
+class ImageEditing:
+    def __init__(self, device):
+        print(f"Initializing ImageEditing to {device}")
+        self.device = device
+        self.mask_former = MaskFormer(device=self.device)
+        self.revision = 'fp16' if 'cuda' in device else None
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
+
+    @prompts(name="Remove Something From The Photo",
+             description="useful when you want to remove an object or something from the photo "
+                         "from its description or location. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the object to be removed. ")
+    def inference_remove(self, inputs):
+        image_path, to_be_removed_txt = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        return self.inference_replace(f"{image_path},{to_be_removed_txt},background")
+
+    @prompts(name="Replace Something From The Photo",
+             description="useful when you want to replace an object from the object description or "
+                         "location with another object from its description. "
+                         "The input to this tool should be a comma separated string of three, "
+                         "representing the image_path, the object to be replaced, the object to be replaced with ")
+    def inference_replace(self, inputs):
+        image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
+        original_image = Image.open(image_path)
+        original_size = original_image.size
+        mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
+        updated_image = self.inpaint(prompt=replace_with_txt, image=original_image.resize((512, 512)),
+                                     mask_image=mask_image.resize((512, 512))).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="replace-something")
+        updated_image = updated_image.resize(original_size)
+        updated_image.save(updated_image_path)
+        print(
+            f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
+            f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class InstructPix2Pix:
+    def __init__(self, device):
+        print(f"Initializing InstructPix2Pix to {device}")
+        self.device = device
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix",
+                                                                           safety_checker=None,
+                                                                           torch_dtype=self.torch_dtype).to(device)
+        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
+
+    @prompts(name="Instruct Image Using Text",
+             description="useful when you want to change the style of the image to be like the text. "
+                         "like: make it look like a painting. or make it like a robot. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the text. ")
+    def inference(self, inputs):
+        """Change style of image."""
+        print("===>Starting InstructPix2Pix Inference")
+        image_path, text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        original_image = Image.open(image_path)
+        image = self.pipe(text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
+        image.save(updated_image_path)
+        print(f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
+              f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class Text2Image:
+    def __init__(self, device):
+        print(f"Initializing Text2Image to {device}")
+        self.device = device
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
+                                                            torch_dtype=self.torch_dtype)
+        self.pipe.to(device)
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
+                        'fewer digits, cropped, worst quality, low quality'
+
+    @prompts(name="Generate Image From User Input Text",
+             description="useful when you want to generate an image from a user input text and save it to a file. "
+                         "like: generate an image of an object or something, or generate an image that includes some objects. "
+                         "The input to this tool should be a string, representing the text used to generate image. ")
+    def inference(self, text):
+        image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
+        prompt = text + ', ' + self.a_prompt
+        image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
+        image.save(image_filename)
+        print(
+            f"\nProcessed Text2Image, Input Text: {text}, Output Image: {image_filename}")
+        return image_filename
+
+
+class ImageCaptioning:
+    def __init__(self, device):
+        print(f"Initializing ImageCaptioning to {device}")
+        self.device = device
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+        self.model = BlipForConditionalGeneration.from_pretrained(
+            "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype).to(self.device)
+
+    @prompts(name="Get Photo Description",
+             description="useful when you want to know what is inside the photo. receives image_path as input. "
+                         "The input to this tool should be a string, representing the image_path. ")
+    def inference(self, image_path):
+        inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device, self.torch_dtype)
+        out = self.model.generate(**inputs)
+        captions = self.processor.decode(out[0], skip_special_tokens=True)
+        print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
+        return captions
+
+
+class Image2Canny:
+    def __init__(self, device):
+        print("Initializing Image2Canny")
+        self.low_threshold = 100
+        self.high_threshold = 200
+
+    @prompts(name="Edge Detection On Image",
+             description="useful when you want to detect the edge of the image. "
+                         "like: detect the edges of this image, or canny detection on image, "
+                         "or perform edge detection on this image, or detect the canny image of this image. "
+                         "The input to this tool should be a string, representing the image_path")
+    def inference(self, inputs):
+        image = Image.open(inputs)
+        image = np.array(image)
+        canny = cv2.Canny(image, self.low_threshold, self.high_threshold)
+        canny = canny[:, :, None]
+        canny = np.concatenate([canny, canny, canny], axis=2)
+        canny = Image.fromarray(canny)
+        updated_image_path = get_new_image_name(inputs, func_name="edge")
+        canny.save(updated_image_path)
+        print(f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text: {updated_image_path}")
+        return updated_image_path
+
+
+class CannyText2Image:
+    def __init__(self, device):
+        print(f"Initializing CannyText2Image to {device}")
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-canny",
+                                                          torch_dtype=self.torch_dtype)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
+            torch_dtype=self.torch_dtype)
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
+        self.seed = -1
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
+                        'fewer digits, cropped, worst quality, low quality'
+
+    @prompts(name="Generate Image Condition On Canny Image",
+             description="useful when you want to generate a new real image from both the user description and a canny image."
+                         " like: generate a real image of a object or something from this canny image,"
+                         " or generate a new real image of a object or something from this edge image. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the user description. ")
+    def inference(self, inputs):
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        self.seed = random.randint(0, 65535)
+        seed_everything(self.seed)
+        prompt = f'{instruct_text}, {self.a_prompt}'
+        image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
+                          guidance_scale=9.0).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="canny2image")
+        image.save(updated_image_path)
+        print(f"\nProcessed CannyText2Image, Input Canny: {image_path}, Input Text: {instruct_text}, "
+              f"Output Text: {updated_image_path}")
+        return updated_image_path
+
+
+class Image2Line:
+    def __init__(self, device):
+        print("Initializing Image2Line")
+        self.detector = MLSDdetector.from_pretrained('lllyasviel/ControlNet')
+
+    @prompts(name="Line Detection On Image",
+             description="useful when you want to detect the straight line of the image. "
+                         "like: detect the straight lines of this image, or straight line detection on image, "
+                         "or perform straight line detection on this image, or detect the straight line image of this image. "
+                         "The input to this tool should be a string, representing the image_path")
+    def inference(self, inputs):
+        image = Image.open(inputs)
+        mlsd = self.detector(image)
+        updated_image_path = get_new_image_name(inputs, func_name="line-of")
+        mlsd.save(updated_image_path)
+        print(f"\nProcessed Image2Line, Input Image: {inputs}, Output Line: {updated_image_path}")
+        return updated_image_path
+
+
+class LineText2Image:
+    def __init__(self, device):
+        print(f"Initializing LineText2Image to {device}")
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-mlsd",
+                                                          torch_dtype=self.torch_dtype)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
+            torch_dtype=self.torch_dtype
+        )
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
+        self.seed = -1
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
+                        'fewer digits, cropped, worst quality, low quality'
+
+    @prompts(name="Generate Image Condition On Line Image",
+             description="useful when you want to generate a new real image from both the user description "
+                         "and a straight line image. "
+                         "like: generate a real image of a object or something from this straight line image, "
+                         "or generate a new real image of a object or something from this straight lines. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the user description. ")
+    def inference(self, inputs):
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        self.seed = random.randint(0, 65535)
+        seed_everything(self.seed)
+        prompt = f'{instruct_text}, {self.a_prompt}'
+        image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
+                          guidance_scale=9.0).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="line2image")
+        image.save(updated_image_path)
+        print(f"\nProcessed LineText2Image, Input Line: {image_path}, Input Text: {instruct_text}, "
+              f"Output Text: {updated_image_path}")
+        return updated_image_path
+
+
+class Image2Hed:
+    def __init__(self, device):
+        print("Initializing Image2Hed")
+        self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')
+
+    @prompts(name="Hed Detection On Image",
+             description="useful when you want to detect the soft hed boundary of the image. "
+                         "like: detect the soft hed boundary of this image, or hed boundary detection on image, "
+                         "or perform hed boundary detection on this image, or detect soft hed boundary image of this image. "
+                         "The input to this tool should be a string, representing the image_path")
+    def inference(self, inputs):
+        image = Image.open(inputs)
+        hed = self.detector(image)
+        updated_image_path = get_new_image_name(inputs, func_name="hed-boundary")
+        hed.save(updated_image_path)
+        print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed: {updated_image_path}")
+        return updated_image_path
+
+
+class HedText2Image:
+    def __init__(self, device):
+        print(f"Initializing HedText2Image to {device}")
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-hed",
+                                                          torch_dtype=self.torch_dtype)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
+            torch_dtype=self.torch_dtype
+        )
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
+        self.seed = -1
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
+                        'fewer digits, cropped, worst quality, low quality'
+
+    @prompts(name="Generate Image Condition On Soft Hed Boundary Image",
+             description="useful when you want to generate a new real image from both the user description "
+                         "and a soft hed boundary image. "
+                         "like: generate a real image of a object or something from this soft hed boundary image, "
+                         "or generate a new real image of a object or something from this hed boundary. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the user description")
+    def inference(self, inputs):
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        self.seed = random.randint(0, 65535)
+        seed_everything(self.seed)
+        prompt = f'{instruct_text}, {self.a_prompt}'
+        image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
+                          guidance_scale=9.0).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="hed2image")
+        image.save(updated_image_path)
+        print(f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text: {instruct_text}, "
+              f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class Image2Scribble:
+    def __init__(self, device):
+        print("Initializing Image2Scribble")
+        self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')
+
+    @prompts(name="Sketch Detection On Image",
+             description="useful when you want to generate a scribble of the image. "
+                         "like: generate a scribble of this image, or generate a sketch from this image, "
+                         "detect the sketch from this image. "
+                         "The input to this tool should be a string, representing the image_path")
+    def inference(self, inputs):
+        image = Image.open(inputs)
+        scribble = self.detector(image, scribble=True)
+        updated_image_path = get_new_image_name(inputs, func_name="scribble")
+        scribble.save(updated_image_path)
+        print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {updated_image_path}")
+        return updated_image_path
+
+
+class ScribbleText2Image:
+    def __init__(self, device):
+        print(f"Initializing ScribbleText2Image to {device}")
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-scribble",
+                                                          torch_dtype=self.torch_dtype)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
+            torch_dtype=self.torch_dtype
+        )
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
+        self.seed = -1
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
+                        'fewer digits, cropped, worst quality, low quality'
+
+    @prompts(name="Generate Image Condition On Sketch Image",
+             description="useful when you want to generate a new real image from both the user description and "
+                         "a scribble image or a sketch image. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the user description")
+    def inference(self, inputs):
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        self.seed = random.randint(0, 65535)
+        seed_everything(self.seed)
+        prompt = f'{instruct_text}, {self.a_prompt}'
+        image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
+                          guidance_scale=9.0).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
+        image.save(updated_image_path)
+        print(f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text: {instruct_text}, "
+              f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class Image2Pose:
+    def __init__(self, device):
+        print("Initializing Image2Pose")
+        self.detector = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
+
+    @prompts(name="Pose Detection On Image",
+             description="useful when you want to detect the human pose of the image. "
+                         "like: generate human poses of this image, or generate a pose image from this image. "
+                         "The input to this tool should be a string, representing the image_path")
+    def inference(self, inputs):
+        image = Image.open(inputs)
+        pose = self.detector(image)
+        updated_image_path = get_new_image_name(inputs, func_name="human-pose")
+        pose.save(updated_image_path)
+        print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
+        return updated_image_path
+
+
+class PoseText2Image:
+    def __init__(self, device):
+        print(f"Initializing PoseText2Image to {device}")
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose",
+                                                          torch_dtype=self.torch_dtype)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
+            torch_dtype=self.torch_dtype)
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
+        self.num_inference_steps = 20
+        self.seed = -1
+        self.unconditional_guidance_scale = 9.0
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
+                        ' fewer digits, cropped, worst quality, low quality'
+
+    @prompts(name="Generate Image Condition On Pose Image",
+             description="useful when you want to generate a new real image from both the user description "
+                         "and a human pose image. "
+                         "like: generate a real image of a human from this human pose image, "
+                         "or generate a new real image of a human from this pose. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the user description")
+    def inference(self, inputs):
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        self.seed = random.randint(0, 65535)
+        seed_everything(self.seed)
+        prompt = f'{instruct_text}, {self.a_prompt}'
+        image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
+                          guidance_scale=9.0).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="pose2image")
+        image.save(updated_image_path)
+        print(f"\nProcessed PoseText2Image, Input Pose: {image_path}, Input Text: {instruct_text}, "
+              f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class Image2Seg:
+    def __init__(self, device):
+        print("Initializing Image2Seg")
+        self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
+        self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
+        self.ade_palette = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+                            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+                            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+                            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+                            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+                            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+                            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+                            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+                            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+                            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+                            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+                            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+                            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+                            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+                            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+                            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+                            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+                            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+                            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+                            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+                            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+                            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+                            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+                            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+                            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+                            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+                            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+                            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+                            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+                            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+                            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+                            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+                            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+                            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+                            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+                            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+                            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+                            [102, 255, 0], [92, 0, 255]]
+
+    @prompts(name="Segmentation On Image",
+             description="useful when you want to detect segmentations of the image. "
+                         "like: segment this image, or generate segmentations on this image, "
+                         "or perform segmentation on this image. "
+                         "The input to this tool should be a string, representing the image_path")
+    def inference(self, inputs):
+        image = Image.open(inputs)
+        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
+        with torch.no_grad():
+            outputs = self.image_segmentor(pixel_values)
+        seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
+        palette = np.array(self.ade_palette)
+        for label, color in enumerate(palette):
+            color_seg[seg == label, :] = color
+        color_seg = color_seg.astype(np.uint8)
+        segmentation = Image.fromarray(color_seg)
+        updated_image_path = get_new_image_name(inputs, func_name="segmentation")
+        segmentation.save(updated_image_path)
+        print(f"\nProcessed Image2Seg, Input Image: {inputs}, Output Seg: {updated_image_path}")
+        return updated_image_path
+
+
+class SegText2Image:
+    def __init__(self, device):
+        print(f"Initializing SegText2Image to {device}")
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-seg",
+                                                          torch_dtype=self.torch_dtype)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
+            torch_dtype=self.torch_dtype)
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
+        self.seed = -1
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
+                        ' fewer digits, cropped, worst quality, low quality'
+
+    @prompts(name="Generate Image Condition On Segmentations",
+             description="useful when you want to generate a new real image from both the user description and segmentations. "
+                         "like: generate a real image of a object or something from this segmentation image, "
+                         "or generate a new real image of a object or something from these segmentations. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the user description")
+    def inference(self, inputs):
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        self.seed = random.randint(0, 65535)
+        seed_everything(self.seed)
+        prompt = f'{instruct_text}, {self.a_prompt}'
+        image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
+                          guidance_scale=9.0).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="segment2image")
+        image.save(updated_image_path)
+        print(f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text: {instruct_text}, "
+              f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class Image2Depth:
+    def __init__(self, device):
+        print("Initializing Image2Depth")
+        self.depth_estimator = pipeline('depth-estimation')
+
+    @prompts(name="Predict Depth On Image",
+             description="useful when you want to detect depth of the image. like: generate the depth from this image, "
+                         "or detect the depth map on this image, or predict the depth for this image. "
+                         "The input to this tool should be a string, representing the image_path")
+    def inference(self, inputs):
+        image = Image.open(inputs)
+        depth = self.depth_estimator(image)['depth']
+        depth = np.array(depth)
+        depth = depth[:, :, None]
+        depth = np.concatenate([depth, depth, depth], axis=2)
+        depth = Image.fromarray(depth)
+        updated_image_path = get_new_image_name(inputs, func_name="depth")
+        depth.save(updated_image_path)
+        print(f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth: {updated_image_path}")
+        return updated_image_path
+
+
+class DepthText2Image:
+    def __init__(self, device):
+        print(f"Initializing DepthText2Image to {device}")
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.controlnet = ControlNetModel.from_pretrained(
+            "fusing/stable-diffusion-v1-5-controlnet-depth", torch_dtype=self.torch_dtype)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
+            torch_dtype=self.torch_dtype)
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
+        self.seed = -1
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
+                        ' fewer digits, cropped, worst quality, low quality'
+
+    @prompts(name="Generate Image Condition On Depth",
+             description="useful when you want to generate a new real image from both the user description and depth image. "
+                         "like: generate a real image of a object or something from this depth image, "
+                         "or generate a new real image of a object or something from the depth map. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the user description")
+    def inference(self, inputs):
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        self.seed = random.randint(0, 65535)
+        seed_everything(self.seed)
+        prompt = f'{instruct_text}, {self.a_prompt}'
+        image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
+                          guidance_scale=9.0).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="depth2image")
+        image.save(updated_image_path)
+        print(f"\nProcessed DepthText2Image, Input Depth: {image_path}, Input Text: {instruct_text}, "
+              f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class Image2Normal:
+    def __init__(self, device):
+        print("Initializing Image2Normal")
+        self.depth_estimator = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas")
+        self.bg_threhold = 0.4
+
+    @prompts(name="Predict Normal Map On Image",
+             description="useful when you want to detect the normal map of the image. "
+                         "like: generate normal map from this image, or predict normal map of this image. "
+                         "The input to this tool should be a string, representing the image_path")
+    def inference(self, inputs):
+        image = Image.open(inputs)
+        original_size = image.size
+        image = self.depth_estimator(image)['predicted_depth'][0]
+        image = image.numpy()
+        image_depth = image.copy()
+        image_depth -= np.min(image_depth)
+        image_depth /= np.max(image_depth)
+        x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
+        x[image_depth < self.bg_threhold] = 0
+        y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
+        y[image_depth < self.bg_threhold] = 0
+        z = np.ones_like(x) * np.pi * 2.0
+        image = np.stack([x, y, z], axis=2)
+        image /= np.sum(image ** 2.0, axis=2, keepdims=True) ** 0.5
+        image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
+        image = Image.fromarray(image)
+        image = image.resize(original_size)
+        updated_image_path = get_new_image_name(inputs, func_name="normal-map")
+        image.save(updated_image_path)
+        print(f"\nProcessed Image2Normal, Input Image: {inputs}, Output Normal: {updated_image_path}")
+        return updated_image_path
+
+
+class NormalText2Image:
+    def __init__(self, device):
+        print(f"Initializing NormalText2Image to {device}")
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.controlnet = ControlNetModel.from_pretrained(
+            "fusing/stable-diffusion-v1-5-controlnet-normal", torch_dtype=self.torch_dtype)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
+            torch_dtype=self.torch_dtype)
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.to(device)
+        self.seed = -1
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
+                        ' fewer digits, cropped, worst quality, low quality'
+
+    @prompts(name="Generate Image Condition On Normal Map",
+             description="useful when you want to generate a new real image from both the user description and normal map. "
+                         "like: generate a real image of a object or something from this normal map, "
+                         "or generate a new real image of a object or something from the normal map. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing the image_path and the user description")
+    def inference(self, inputs):
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        self.seed = random.randint(0, 65535)
+        seed_everything(self.seed)
+        prompt = f'{instruct_text}, {self.a_prompt}'
+        image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
+                          guidance_scale=9.0).images[0]
+        updated_image_path = get_new_image_name(image_path, func_name="normal2image")
+        image.save(updated_image_path)
+        print(f"\nProcessed NormalText2Image, Input Normal: {image_path}, Input Text: {instruct_text}, "
+              f"Output Image: {updated_image_path}")
+        return updated_image_path
+
+
+class VisualQuestionAnswering:
+    def __init__(self, device):
+        print(f"Initializing VisualQuestionAnswering to {device}")
+        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
+        self.device = device
+        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
+        self.model = BlipForQuestionAnswering.from_pretrained(
+            "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
+
+    @prompts(name="Answer Question About The Image",
+             description="useful when you need an answer for a question based on an image. "
+                         "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
+                         "The input to this tool should be a comma separated string of two, representing the image_path and the question")
+    def inference(self, inputs):
+        image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        raw_image = Image.open(image_path).convert('RGB')
+        inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
+        out = self.model.generate(**inputs)
+        answer = self.processor.decode(out[0], skip_special_tokens=True)
+        print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
+              f"Output Answer: {answer}")
+        return answer
+
+class InfinityOutPainting:
+    template_model = True  # Add this line to show this is a template model.
+    def __init__(self, ImageCaptioning, ImageEditing, VisualQuestionAnswering):
+        # self.llm = OpenAI(temperature=0)
+        self.ImageCaption = ImageCaptioning
+        self.ImageEditing = ImageEditing
+        self.ImageVQA = VisualQuestionAnswering
+        self.a_prompt = 'best quality, extremely detailed'
+        self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
+                        'fewer digits, cropped, worst quality, low quality'
+
+    def get_BLIP_vqa(self, image, question):
+        inputs = self.ImageVQA.processor(image, question, return_tensors="pt").to(self.ImageVQA.device,
+                                                                                  self.ImageVQA.torch_dtype)
+        out = self.ImageVQA.model.generate(**inputs)
+        answer = self.ImageVQA.processor.decode(out[0], skip_special_tokens=True)
+        print(f"\nProcessed VisualQuestionAnswering, Input Question: {question}, Output Answer: {answer}")
+        return answer
+
+    def get_BLIP_caption(self, image):
+        inputs = self.ImageCaption.processor(image, return_tensors="pt").to(self.ImageCaption.device,
+                                                                            self.ImageCaption.torch_dtype)
+        out = self.ImageCaption.model.generate(**inputs)
+        BLIP_caption = self.ImageCaption.processor.decode(out[0], skip_special_tokens=True)
+        return BLIP_caption
+
+    # def check_prompt(self, prompt):
+    #     check = f"Here is a paragraph with adjectives. " \
+    #             f"{prompt} " \
+    #             f"Please change all plural forms in the adjectives to singular forms. "
+    #     return self.llm(check)
+
+    def get_imagine_caption(self, image, imagine):
+        BLIP_caption = self.get_BLIP_caption(image)
+        background_color = self.get_BLIP_vqa(image, 'what is the background color of this image')
+        style = self.get_BLIP_vqa(image, 'what is the style of this image')
+        imagine_prompt = f"let's pretend you are an excellent painter and now " \
+                         f"there is an incomplete painting with {BLIP_caption} in the center, " \
+                         f"please imagine the complete painting and describe it" \
+                         f"you should consider the background color is {background_color}, the style is {style}" \
+                         f"You should make the painting as vivid and realistic as possible" \
+                         f"You can not use words like painting or picture" \
+                         f"and you should use no more than 50 words to describe it"
+        # caption = self.llm(imagine_prompt) if imagine else BLIP_caption
+        caption = BLIP_caption
+        # caption = self.check_prompt(caption)
+        print(f'BLIP observation: {BLIP_caption}, ChatGPT imagine to {caption}') if imagine else print(
+            f'Prompt: {caption}')
+        return caption
+
+    def resize_image(self, image, max_size=100000, multiple=8):
+        aspect_ratio = image.size[0] / image.size[1]
+        new_width = int(math.sqrt(max_size * aspect_ratio))
+        new_height = int(new_width / aspect_ratio)
+        new_width, new_height = new_width - (new_width % multiple), new_height - (new_height % multiple)
+        return image.resize((new_width, new_height))
+
+    def dowhile(self, original_img, tosize, expand_ratio, imagine, usr_prompt):
+        old_img = original_img
+        while (old_img.size != tosize):
+            prompt = self.check_prompt(usr_prompt) if usr_prompt else self.get_imagine_caption(old_img, imagine)
+            crop_w = 15 if old_img.size[0] != tosize[0] else 0
+            crop_h = 15 if old_img.size[1] != tosize[1] else 0
+            old_img = ImageOps.crop(old_img, (crop_w, crop_h, crop_w, crop_h))
+            temp_canvas_size = (expand_ratio * old_img.width if expand_ratio * old_img.width < tosize[0] else tosize[0],
+                                expand_ratio * old_img.height if expand_ratio * old_img.height < tosize[1] else tosize[
+                                    1])
+            temp_canvas, temp_mask = Image.new("RGB", temp_canvas_size, color="white"), Image.new("L", temp_canvas_size,
+                                                                                                  color="white")
+            x, y = (temp_canvas.width - old_img.width) // 2, (temp_canvas.height - old_img.height) // 2
+            temp_canvas.paste(old_img, (x, y))
+            temp_mask.paste(0, (x, y, x + old_img.width, y + old_img.height))
+            resized_temp_canvas, resized_temp_mask = self.resize_image(temp_canvas), self.resize_image(temp_mask)
+            image = self.ImageEditing.inpaint(prompt=prompt, image=resized_temp_canvas, mask_image=resized_temp_mask,
+                                              height=resized_temp_canvas.height, width=resized_temp_canvas.width,
+                                              num_inference_steps=50).images[0].resize(
+                (temp_canvas.width, temp_canvas.height), Image.ANTIALIAS)
+            image = blend_gt2pt(old_img, image)
+            old_img = image
+        return old_img
+
+    @prompts(name="Extend An Image",
+             description="useful when you need to extend an image into a larger image. "
+                         "like: extend the image into a resolution of 2048x1024, extend the image into 2048x1024. "
+                         "The input to this tool should be a comma separated string of two, representing the image_path and the resolution of widthxheight")
+    def inference(self, inputs):
+        image_path, resolution = inputs.split(',')
+        width, height = resolution.split('x')
+        tosize = (int(width), int(height))
+        image = Image.open(image_path)
+        image = ImageOps.crop(image, (10, 10, 10, 10))
+        out_painted_image = self.dowhile(image, tosize, 4, True, False)
+        updated_image_path = get_new_image_name(image_path, func_name="outpainting")
+        out_painted_image.save(updated_image_path)
+        print(f"\nProcessed InfinityOutPainting, Input Image: {image_path}, Input Resolution: {resolution}, "
+              f"Output Image: {updated_image_path}")
+        return updated_image_path
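
Usage note (not part of the commit): a minimal sketch of how these tool classes can be driven directly, assuming the file is importable as `visual_foundation_models`, a CUDA device is available, and an `image/` directory exists in the working directory for `Text2Image` outputs. The prompt text and object description below are illustrative only.

from visual_foundation_models import Text2Image, ImageCaptioning, ImageEditing

text2image = Text2Image(device="cuda:0")      # loads runwayml/stable-diffusion-v1-5 in fp16
captioner = ImageCaptioning(device="cuda:0")  # loads Salesforce/blip-image-captioning-base
editor = ImageEditing(device="cuda:0")        # CLIPSeg mask + stable-diffusion-inpainting

# Generate an image, caption it, then remove an object by description.
generated_path = text2image.inference("a red bicycle leaning against a brick wall")
print(captioner.inference(generated_path))
edited_path = editor.inference_remove(f"{generated_path}, the bicycle")
print(edited_path)

In the agent setting this file appears designed for, the classes are usually not called by hand like this; a controller enumerates the methods carrying the `name`/`description` attributes set by `@prompts` and exposes them as tools, passing the comma-separated string inputs described in each tool description.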