X-iZhang commited on
Commit
e1b2b95
·
verified ·
1 Parent(s): 97a468f

Upload run_libra.py

Browse files
Files changed (1) hide show
  1. libra/eval/run_libra.py +29 -10
libra/eval/run_libra.py CHANGED
@@ -14,6 +14,21 @@ from io import BytesIO
14
  from pydicom.pixel_data_handlers.util import apply_voi_lut
15
  import datetime
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def load_images(image_file):
19
  """
@@ -77,7 +92,7 @@ def load_images(image_file):
77
 
78
  return image
79
 
80
- def get_image_tensors(image_path, image_processor, model, device='cpu'):
81
  # Load and preprocess the images
82
  if isinstance(image_path, str):
83
  image = []
@@ -118,19 +133,24 @@ def libra_eval(
118
  model_base=None,
119
  image_file=None,
120
  query=None,
121
- conv_mode="libra_v1",
122
  temperature=0.2,
123
  top_p=None,
124
  num_beams=1,
125
  num_return_sequences=None,
126
  length_penalty=1.0,
127
- max_new_tokens=128
 
128
  ):
129
  # Model
130
  disable_torch_init()
131
 
132
- model_name = get_model_name_from_path(model_path)
133
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name)
 
 
 
 
134
 
135
  qs = query
136
  if model.config.mm_use_im_start_end:
@@ -151,7 +171,7 @@ def libra_eval(
151
  conv.append_message(conv.roles[1], None)
152
  prompt = conv.get_prompt()
153
 
154
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to("cpu")
155
  attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
156
  pad_token_id = tokenizer.pad_token_id
157
 
@@ -162,7 +182,7 @@ def libra_eval(
162
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
163
 
164
  with torch.inference_mode():
165
-
166
  if num_beams > 1:
167
  output_ids = model.generate(
168
  input_ids=input_ids,
@@ -192,7 +212,7 @@ def libra_eval(
192
  pad_token_id=pad_token_id,
193
  stopping_criteria=[stopping_criteria],
194
  use_cache=True)
195
-
196
  input_token_len = input_ids.shape[1]
197
  n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
198
 
@@ -205,8 +225,7 @@ def libra_eval(
205
  if outputs.endswith(stop_str):
206
  outputs = outputs[:-len(stop_str)]
207
  outputs = outputs.strip()
208
-
209
- print("outputs",outputs)
210
  return outputs
211
 
212
  if __name__ == "__main__":
 
14
  from pydicom.pixel_data_handlers.util import apply_voi_lut
15
  import datetime
16
 
17
+ def load_model(model_path, model_base=None):
18
+ """
19
+ Load the model and return its components.
20
+
21
+ Args:
22
+ model_path (str): Path to the model.
23
+ model_base (str): Base model, if any.
24
+
25
+ Returns:
26
+ tuple: (tokenizer, model, image_processor, context_len)
27
+ """
28
+ disable_torch_init()
29
+ model_name = get_model_name_from_path(model_path)
30
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name)
31
+ return tokenizer, model, image_processor, context_len
32
 
33
  def load_images(image_file):
34
  """
 
92
 
93
  return image
94
 
95
+ def get_image_tensors(image_path, image_processor, model, device='cuda'):
96
  # Load and preprocess the images
97
  if isinstance(image_path, str):
98
  image = []
 
133
  model_base=None,
134
  image_file=None,
135
  query=None,
136
+ conv_mode=None,
137
  temperature=0.2,
138
  top_p=None,
139
  num_beams=1,
140
  num_return_sequences=None,
141
  length_penalty=1.0,
142
+ max_new_tokens=128,
143
+ libra_model=None
144
  ):
145
  # Model
146
  disable_torch_init()
147
 
148
+ if libra_model is not None:
149
+ tokenizer, model, image_processor, context_len = libra_model
150
+ model_name = model.config._name_or_path
151
+ else:
152
+ tokenizer, model, image_processor, context_len = load_model(model_path, model_base)
153
+ model_name = get_model_name_from_path(model_path)
154
 
155
  qs = query
156
  if model.config.mm_use_im_start_end:
 
171
  conv.append_message(conv.roles[1], None)
172
  prompt = conv.get_prompt()
173
 
174
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
175
  attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
176
  pad_token_id = tokenizer.pad_token_id
177
 
 
182
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
183
 
184
  with torch.inference_mode():
185
+ torch.cuda.empty_cache()
186
  if num_beams > 1:
187
  output_ids = model.generate(
188
  input_ids=input_ids,
 
212
  pad_token_id=pad_token_id,
213
  stopping_criteria=[stopping_criteria],
214
  use_cache=True)
215
+
216
  input_token_len = input_ids.shape[1]
217
  n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
218
 
 
225
  if outputs.endswith(stop_str):
226
  outputs = outputs[:-len(stop_str)]
227
  outputs = outputs.strip()
228
+
 
229
  return outputs
230
 
231
  if __name__ == "__main__":