Update inference.py

inference.py CHANGED (+30 -13)
@@ -1,9 +1,13 @@
 import torch
 from PIL import Image
 from conversation import conv_templates
-from builder import load_pretrained_model
+from builder import load_pretrained_model  # Assuming this is your custom model loader
 from functools import partial
 import numpy as np
+DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
 
 # define the task categories
 box_in_tasks = ['widgetcaptions', 'taperception', 'ocr', 'icon_recognition', 'widget_classification', 'example_0']
@@ -36,20 +40,20 @@ def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
     if mask is not None:
         coor_mask = coor_mask * mask
 
-    #
+    # convert to torch tensor and ensure it contains non-zero values
    coor_mask = torch.from_numpy(coor_mask)
     assert len(coor_mask.nonzero()) != 0, "Generated mask is empty :("
 
     return coor_mask
 
 
-def infer_single_prompt(image_path, prompt, model_path, region=None, model_name="ferret_llama", conv_mode="ferret_llama_3"):
+def infer_single_prompt(image_path, prompt, model_path, region=None, model_name="ferret_llama", conv_mode="ferret_llama_3", add_region_feature=False):
     img = Image.open(image_path).convert('RGB')
 
     # this loads the model, image processor and tokenizer
     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
-    # define the image size (e.g., 224x224 or 336x336)
+    # define the image size required by clip
     image_size = {"height": 336, "width": 336}
 
     # process the image
@@ -68,17 +72,27 @@ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name=
     conv.append_message(conv.roles[0], prompt)
     conv.append_message(conv.roles[1], None)
     prompt_input = conv.get_prompt()
+
+    # add the special tokens
+    prompt_input = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt_input
 
-    input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()
-
     # region mask logic (if region is provided)
     region_masks = None
-    if region is not None:
+    if add_region_feature and region is not None:
         raw_w, raw_h = img.size
         region_masks = generate_mask_for_feature(region, raw_w, raw_h).unsqueeze(0).cuda().half()
-        region_masks = [[
+        region_masks = [[region_mask_i.cuda().half() for region_mask_i in region_masks]]
+        prompt_input = prompt_input.replace("<bbox_location0>", f"[{region[0]}, {region[1]}, {region[2]}, {region[3]}] {DEFAULT_REGION_FEA_TOKEN}")
+
+    # tokenize prompt
+    # input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()
+    inputs = tokenizer(prompt_input, return_tensors='pt', padding=True)
+    input_ids = inputs['input_ids'].cuda()
+    attention_mask = inputs['attention_mask'].cuda()
+
     # generate model output
     with torch.inference_mode():
         # Use region_masks in model's forward call
@@ -87,9 +101,11 @@ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name=
             model.orig_forward,
             region_masks=region_masks
         )
+        # explicitly add the attention mask
         output_ids = model.generate(
             input_ids,
             images=image_tensor,
+            attention_mask=attention_mask,
             max_new_tokens=1024,
             num_beams=1,
             region_masks=region_masks,  # pass the region mask to the model
@@ -102,7 +118,8 @@ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name=
     return output_text.strip()
 
 # We also define a task-specific inference function
-def infer_ui_task(image_path, prompt, model_path, task, region=None):
+def infer_ui_task(image_path, prompt, model_path, task, region=None, add_region_feature=False):
+    # region = torch.tensor(region).cuda()
     """
     Handles task types: box_in_tasks, box_out_tasks, no_box_tasks.
     """
@@ -111,7 +128,7 @@ def infer_ui_task(image_path, prompt, model_path, task, region=None):
 
     if task in box_in_tasks:
         print(f"Processing {task} with bounding box region.")
-        return infer_single_prompt(image_path, prompt, model_path, region)
+        return infer_single_prompt(image_path, prompt, model_path, region, add_region_feature=add_region_feature)
 
     elif task in box_out_tasks:
         print(f"Processing {task} without bounding box region.")
@@ -122,4 +139,4 @@ def infer_ui_task(image_path, prompt, model_path, task, region=None):
         return infer_single_prompt(image_path, prompt, model_path)
 
     else:
-        raise ValueError(f"Unknown task type: {task}")
+        raise ValueError(f"Unknown task type: {task}")