Update demo/seagull_inference.py
Browse files
- demo/seagull_inference.py +25 -12
demo/seagull_inference.py
CHANGED
@@ -13,6 +13,7 @@ import numpy as np
|
|
13 |
import cv2
|
14 |
from typing import List
|
15 |
from PIL import Image
|
|
|
16 |
|
17 |
class Seagull():
|
18 |
def __init__(self, model_path, device='cuda'):
|
@@ -40,9 +41,9 @@ class Seagull():
|
|
40 |
begin_str = "<image>\nThis provides an overview of the image.\n Please answer the following questions about the provided region. Note: Distortions include: blur, colorfulness, compression, contrast exposure and noise.\n Here is the region <global><local>. "
|
41 |
|
42 |
instruction = {
|
43 |
-
'distortion
|
44 |
-
'quality
|
45 |
-
'importance
|
46 |
}
|
47 |
|
48 |
self.ids_input = {}
|
@@ -70,7 +71,7 @@ class Seagull():
|
|
70 |
else:
|
71 |
preprocessed_img = img.copy()
|
72 |
|
73 |
-
return (preprocessed_img, preprocessed_img, preprocessed_img)
|
74 |
|
75 |
def preprocess(self, img):
|
76 |
image = self.image_processor.preprocess(img,
|
@@ -83,19 +84,31 @@ class Seagull():
|
|
83 |
align_corners=False).squeeze(0)
|
84 |
|
85 |
return image
|
86 |
-
|
87 |
-
def seagull_predict(self, img, mask, instruct_type):
|
88 |
-
image = self.preprocess(img)
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
mask = np.array(mask, dtype=np.int)
|
|
|
91 |
ys, xs = np.where(mask > 0)
|
92 |
if len(xs) > 0 and len(ys) > 0:
|
93 |
-
# Find the minimal bounding rectangle for the entire mask
|
94 |
x_min, x_max = np.min(xs), np.max(xs)
|
95 |
y_min, y_max = np.min(ys), np.max(ys)
|
96 |
w1 = x_max - x_min
|
97 |
h1 = y_max - y_min
|
98 |
-
|
99 |
bounding_box = (x_min, y_min, w1, h1)
|
100 |
else:
|
101 |
bounding_box = None
|
@@ -104,7 +117,7 @@ class Seagull():
|
|
104 |
mask = np.array(mask > 0.1, dtype=np.uint8)
|
105 |
masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)
|
106 |
|
107 |
-
input_ids = self.ids_input[instruct_type.lower()]
|
108 |
|
109 |
x1, y1, w1, h1 = list(map(int, bounding_box)) # x y w h
|
110 |
cropped_img = img[y1:y1 + h1, x1:x1 + w1]
|
@@ -127,8 +140,8 @@ class Seagull():
|
|
127 |
max_new_tokens=2048,
|
128 |
use_cache=True,
|
129 |
num_beams=1,
|
130 |
-
top_k = 0,
|
131 |
-
top_p = 1,
|
132 |
)
|
133 |
|
134 |
self.model.forward = self.model.orig_forward
|
|
|
13 |
import cv2
|
14 |
from typing import List
|
15 |
from PIL import Image
|
16 |
+
from pycocotools import mask as mask_utils
|
17 |
|
18 |
class Seagull():
|
19 |
def __init__(self, model_path, device='cuda'):
|
|
|
41 |
begin_str = "<image>\nThis provides an overview of the image.\n Please answer the following questions about the provided region. Note: Distortions include: blur, colorfulness, compression, contrast exposure and noise.\n Here is the region <global><local>. "
|
42 |
|
43 |
instruction = {
|
44 |
+
'distortion': 'Provide the distortion type of this region.',
|
45 |
+
'quality': 'Analyze the quality of this region.',
|
46 |
+
'importance': 'Consider the impact of this region on the overall image quality. Analyze its importance to the overall image quality.'
|
47 |
}
|
48 |
|
49 |
self.ids_input = {}
|
|
|
71 |
else:
|
72 |
preprocessed_img = img.copy()
|
73 |
|
74 |
+
return (preprocessed_img, preprocessed_img, preprocessed_img, preprocessed_img)
|
75 |
|
76 |
def preprocess(self, img):
|
77 |
image = self.image_processor.preprocess(img,
|
|
|
84 |
align_corners=False).squeeze(0)
|
85 |
|
86 |
return image
|
|
|
|
|
|
|
87 |
|
88 |
+
def seagull_predict(self, img, mask, instruct_type, mask_type='rle'):
|
89 |
+
if isinstance(img, str):
|
90 |
+
img = cv2.imread(img)
|
91 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
92 |
+
h, w, _ = img.shape
|
93 |
+
|
94 |
+
if mask_type == 'rle': # use the mask to indicate the roi
|
95 |
+
compressed_rle = {'size' : [h, w], 'counts' : mask}
|
96 |
+
mask = mask_utils.decode(compressed_rle)
|
97 |
+
elif mask_type == 'points': # use the point to indicate the roi
|
98 |
+
x_min, y_min, w1, h1 = mask
|
99 |
+
x_max, y_max = x_min + w1, y_min + h1
|
100 |
+
mask = np.zeros_like(img[:, :, 0])
|
101 |
+
mask[max(0, y_min):min(y_max, mask.shape[0]), max(0, x_min):min(x_max, mask.shape[1])] = 1
|
102 |
+
|
103 |
+
image = self.preprocess(img)
|
104 |
mask = np.array(mask, dtype=np.int)
|
105 |
+
|
106 |
ys, xs = np.where(mask > 0)
|
107 |
if len(xs) > 0 and len(ys) > 0:
|
|
|
108 |
x_min, x_max = np.min(xs), np.max(xs)
|
109 |
y_min, y_max = np.min(ys), np.max(ys)
|
110 |
w1 = x_max - x_min
|
111 |
h1 = y_max - y_min
|
|
|
112 |
bounding_box = (x_min, y_min, w1, h1)
|
113 |
else:
|
114 |
bounding_box = None
|
|
|
117 |
mask = np.array(mask > 0.1, dtype=np.uint8)
|
118 |
masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)
|
119 |
|
120 |
+
input_ids = self.ids_input[instruct_type.split()[0].lower()]
|
121 |
|
122 |
x1, y1, w1, h1 = list(map(int, bounding_box)) # x y w h
|
123 |
cropped_img = img[y1:y1 + h1, x1:x1 + w1]
|
|
|
140 |
max_new_tokens=2048,
|
141 |
use_cache=True,
|
142 |
num_beams=1,
|
143 |
+
top_k = 0,
|
144 |
+
top_p = 1,
|
145 |
)
|
146 |
|
147 |
self.model.forward = self.model.orig_forward
|