TongkunGuan committed
Commit 75b4642 · verified · Parent: 3d2b840

Update app.py

Files changed (1): app.py (+11 -11)
app.py CHANGED
@@ -115,22 +115,22 @@ def load_model(check_type):
 def process_image(model, tokenizer, transform, device, check_type, image, text):
     global current_vis, current_bpe, current_index
     src_size = image.size
-    # Ensure all processing is done on the correct device
-    image = image.to(device)
 
+    # Convert PIL Image to Tensor and move to the appropriate device
     if 'TokenOCR' in check_type:
+        # If dynamic preprocessing is required, handle differently
         images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
                                                   image_size=model.config.force_image_size,
                                                   use_thumbnail=model.config.use_thumbnail,
                                                   return_ratio=True)
-        pixel_values = torch.stack([transform(img) for img in images]).to(device)
+        pixel_values = torch.stack([transform(img).to(device) for img in images])
     else:
-        pixel_values = torch.stack([transform(image)]).to(device)
+        # Standard image processing for a single image
+        pixel_values = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to device
         target_ratio = (1, 1)
 
     text += ' '
-    input_ids = tokenizer(text)['input_ids'][1:]
-    input_ids = torch.tensor(input_ids, device=device)
+    input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)  # Ensure tokens are on the same device
 
     with torch.no_grad():
         if 'R50' in check_type:
@@ -147,14 +147,14 @@ def process_image(model, tokenizer, transform, device, check_type, image, text):
         resized_size = size1 if size1 is not None else size2
 
         attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
-        all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
-        current_vis = generate_similiarity_map([image.cpu()], attn_map.cpu(),
-                                               [tokenizer.decode([i]) for i in input_ids],
+        current_vis = generate_similiarity_map([image], attn_map,
+                                               [tokenizer.decode([i]) for i in input_ids.squeeze()],
                                                [], target_ratio, src_size)
 
-        current_bpe = [tokenizer.decode([i]) for i in input_ids]
+        current_bpe = [tokenizer.decode([i]) for i in input_ids.squeeze()]
         current_bpe[-1] = text
-    return image.cpu(), current_vis[0], current_bpe[0]
+    return image, current_vis[0], current_bpe[0]
+
 
 # 事件处理函数
 def update_index(change):
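
The functional core of this diff is the tokenization switch: calling the tokenizer with return_tensors='pt' yields a 2-D [1, seq_len] tensor and keeps whichever leading token the old ['input_ids'][1:] slice discarded (usually a BOS/CLS special token, depending on the tokenizer), which is why the decoding calls now go through input_ids.squeeze(). Below is a minimal sketch of the two patterns, assuming a generic Hugging Face tokenizer as a stand-in for the one app.py actually loads (the checkpoint name is a placeholder):

    import torch
    from transformers import AutoTokenizer

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # placeholder, not the app's tokenizer
    text = 'check this image '

    # Old pattern: Python list of ids, first (special) token sliced off, then a 1-D tensor.
    old_ids = torch.tensor(tokenizer(text)['input_ids'][1:], device=device)   # shape [seq_len - 1]

    # New pattern: batched 2-D tensor created by the tokenizer, moved to the device in one step.
    new_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)       # shape [1, seq_len]

    # Per-token strings, as in the updated current_bpe / generate_similiarity_map calls.
    per_token = [tokenizer.decode([i]) for i in new_ids.squeeze()]

On the image side, the two formulations in the else-branch are shape-equivalent: torch.stack([transform(image)]).to(device) and transform(image).unsqueeze(0).to(device) both produce a [1, C, H, W] batch on the target device; the new form simply avoids building a one-element list.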