phamvi856 commited on
Commit
391f127
·
1 Parent(s): d406dd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -58,15 +58,17 @@ def process_image(image):
58
  input_ids = encoding.input_ids.to(device)
59
  attention_mask = encoding.attention_mask.to(device)
60
  bbox = encoding.bbox[0].tolist()
61
- bbox = torch.tensor(bbox, dtype=torch.float32).unsqueeze(0).to(device)
62
 
63
  # Inference
64
- outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
 
 
65
  predicted_labels = outputs.logits.argmax(dim=2).squeeze().tolist()
66
 
67
  # Extract content from boxes
68
  extracted_content = {}
69
- for idx, box in enumerate(bbox):
70
  predicted_label = id2label[predicted_labels[idx]]
71
  box_width = np.array(box)[2] - np.array(box)[0]
72
  box_height = np.array(box)[3] - np.array(box)[1]
@@ -76,7 +78,7 @@ def process_image(image):
76
  # Draw predictions over the image
77
  draw = ImageDraw.Draw(image)
78
  font = ImageFont.load_default()
79
- for prediction, box in zip(predicted_labels, bbox):
80
  predicted_label = iob_to_label(id2label[prediction])
81
  box_width = np.array(box)[2] - np.array(box)[0]
82
  box_height = np.array(box)[3] - np.array(box)[1]
@@ -108,3 +110,4 @@ iface = gr.Interface(fn=process_image,
108
  iface.launch(inline=False, share=False, debug=False)
109
 
110
 
 
 
58
  input_ids = encoding.input_ids.to(device)
59
  attention_mask = encoding.attention_mask.to(device)
60
  bbox = encoding.bbox[0].tolist()
61
+ bbox = torch.tensor(bbox, dtype=torch.long).unsqueeze(0).to(device)
62
 
63
  # Inference
64
+ with torch.no_grad():
65
+ outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
66
+
67
  predicted_labels = outputs.logits.argmax(dim=2).squeeze().tolist()
68
 
69
  # Extract content from boxes
70
  extracted_content = {}
71
+ for idx, box in enumerate(bbox[0]):
72
  predicted_label = id2label[predicted_labels[idx]]
73
  box_width = np.array(box)[2] - np.array(box)[0]
74
  box_height = np.array(box)[3] - np.array(box)[1]
 
78
  # Draw predictions over the image
79
  draw = ImageDraw.Draw(image)
80
  font = ImageFont.load_default()
81
+ for prediction, box in zip(predicted_labels, bbox[0]):
82
  predicted_label = iob_to_label(id2label[prediction])
83
  box_width = np.array(box)[2] - np.array(box)[0]
84
  box_height = np.array(box)[3] - np.array(box)[1]
 
110
  iface.launch(inline=False, share=False, debug=False)
111
 
112
 
113
+