Zevin2023 committed on
Commit
d29aeb2
·
verified ·
1 Parent(s): a11427d

Update demo/seagull_inference.py

Browse files
Files changed (1) hide show
  1. demo/seagull_inference.py +25 -12
demo/seagull_inference.py CHANGED
@@ -13,6 +13,7 @@ import numpy as np
13
  import cv2
14
  from typing import List
15
  from PIL import Image
 
16
 
17
  class Seagull():
18
  def __init__(self, model_path, device='cuda'):
@@ -40,9 +41,9 @@ class Seagull():
40
  begin_str = "<image>\nThis provides an overview of the image.\n Please answer the following questions about the provided region. Note: Distortions include: blur, colorfulness, compression, contrast exposure and noise.\n Here is the region <global><local>. "
41
 
42
  instruction = {
43
- 'distortion analysis': 'Provide the distortion type of this region.',
44
- 'quality score': 'Analyze the quality of this region.',
45
- 'importance score': 'Consider the impact of this region on the overall image quality. Analyze its importance to the overall image quality.'
46
  }
47
 
48
  self.ids_input = {}
@@ -70,7 +71,7 @@ class Seagull():
70
  else:
71
  preprocessed_img = img.copy()
72
 
73
- return (preprocessed_img, preprocessed_img, preprocessed_img)
74
 
75
  def preprocess(self, img):
76
  image = self.image_processor.preprocess(img,
@@ -83,19 +84,31 @@ class Seagull():
83
  align_corners=False).squeeze(0)
84
 
85
  return image
86
-
87
- def seagull_predict(self, img, mask, instruct_type):
88
- image = self.preprocess(img)
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  mask = np.array(mask, dtype=np.int)
 
91
  ys, xs = np.where(mask > 0)
92
  if len(xs) > 0 and len(ys) > 0:
93
- # Find the minimal bounding rectangle for the entire mask
94
  x_min, x_max = np.min(xs), np.max(xs)
95
  y_min, y_max = np.min(ys), np.max(ys)
96
  w1 = x_max - x_min
97
  h1 = y_max - y_min
98
-
99
  bounding_box = (x_min, y_min, w1, h1)
100
  else:
101
  bounding_box = None
@@ -104,7 +117,7 @@ class Seagull():
104
  mask = np.array(mask > 0.1, dtype=np.uint8)
105
  masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)
106
 
107
- input_ids = self.ids_input[instruct_type.lower()]
108
 
109
  x1, y1, w1, h1 = list(map(int, bounding_box)) # x y w h
110
  cropped_img = img[y1:y1 + h1, x1:x1 + w1]
@@ -127,8 +140,8 @@ class Seagull():
127
  max_new_tokens=2048,
128
  use_cache=True,
129
  num_beams=1,
130
- top_k = 0, # 不进行topk
131
- top_p = 1, # 累计概率为
132
  )
133
 
134
  self.model.forward = self.model.orig_forward
 
13
  import cv2
14
  from typing import List
15
  from PIL import Image
16
+ from pycocotools import mask as mask_utils
17
 
18
  class Seagull():
19
  def __init__(self, model_path, device='cuda'):
 
41
  begin_str = "<image>\nThis provides an overview of the image.\n Please answer the following questions about the provided region. Note: Distortions include: blur, colorfulness, compression, contrast exposure and noise.\n Here is the region <global><local>. "
42
 
43
  instruction = {
44
+ 'distortion': 'Provide the distortion type of this region.',
45
+ 'quality': 'Analyze the quality of this region.',
46
+ 'importance': 'Consider the impact of this region on the overall image quality. Analyze its importance to the overall image quality.'
47
  }
48
 
49
  self.ids_input = {}
 
71
  else:
72
  preprocessed_img = img.copy()
73
 
74
+ return (preprocessed_img, preprocessed_img, preprocessed_img, preprocessed_img)
75
 
76
  def preprocess(self, img):
77
  image = self.image_processor.preprocess(img,
 
84
  align_corners=False).squeeze(0)
85
 
86
  return image
 
 
 
87
 
88
+ def seagull_predict(self, img, mask, instruct_type, mask_type='rle'):
89
+ if isinstance(img, str):
90
+ img = cv2.imread(img)
91
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
92
+ h, w, _ = img.shape
93
+
94
+ if mask_type == 'rle': # use the mask to indicate the roi
95
+ compressed_rle = {'size' : [h, w], 'counts' : mask}
96
+ mask = mask_utils.decode(compressed_rle)
97
+ elif mask_type == 'points': # use the point to indicate the roi
98
+ x_min, y_min, w1, h1 = mask
99
+ x_max, y_max = x_min + w1, y_min + h1
100
+ mask = np.zeros_like(img[:, :, 0])
101
+ mask[max(0, y_min):min(y_max, mask.shape[0]), max(0, x_min):min(x_max, mask.shape[1])] = 1
102
+
103
+ image = self.preprocess(img)
104
  mask = np.array(mask, dtype=np.int)
105
+
106
  ys, xs = np.where(mask > 0)
107
  if len(xs) > 0 and len(ys) > 0:
 
108
  x_min, x_max = np.min(xs), np.max(xs)
109
  y_min, y_max = np.min(ys), np.max(ys)
110
  w1 = x_max - x_min
111
  h1 = y_max - y_min
 
112
  bounding_box = (x_min, y_min, w1, h1)
113
  else:
114
  bounding_box = None
 
117
  mask = np.array(mask > 0.1, dtype=np.uint8)
118
  masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)
119
 
120
+ input_ids = self.ids_input[instruct_type.split()[0].lower()]
121
 
122
  x1, y1, w1, h1 = list(map(int, bounding_box)) # x y w h
123
  cropped_img = img[y1:y1 + h1, x1:x1 + w1]
 
140
  max_new_tokens=2048,
141
  use_cache=True,
142
  num_beams=1,
143
+ top_k = 0,
144
+ top_p = 1,
145
  )
146
 
147
  self.model.forward = self.model.orig_forward