keldrenloy committed on
Commit acfcb3a · Parent: 8aa0e27

Update app.py

Files changed (1)
  1. app.py +8 -12
app.py CHANGED
@@ -80,7 +80,6 @@ def unnormalize_box(bbox, width, height):
     ]
 
 def predict(image):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = LayoutLMv3ForTokenClassification.from_pretrained("keldrenloy/layoutlmv3cordfinetuned").to(device) #add your model directory here
     processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
     label_list,id2label,label2id, num_labels = convert_l2n_n2l(dataset)
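With the local definition removed, the `.to(device)` call kept by this hunk only resolves if `device` is defined at module scope elsewhere in app.py. A minimal sketch of that assumed top-of-file definition (it sits outside this diff's context, so this is a guess, not the file's actual code):

import torch

# assumed module-level device; not visible in this diff
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")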
@@ -89,7 +88,7 @@ def predict(image):
     encoding_inputs = processor(image,return_offsets_mapping=True, return_tensors="pt",truncation = True)
     offset_mapping = encoding_inputs.pop('offset_mapping')
     for k,v in encoding_inputs.items():
-        encoding_inputs[k] = v.to(device)
+        encoding_inputs[k] = v
 
     with torch.no_grad():
         outputs = model(**encoding_inputs)
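The new loop body copies each tensor back unchanged, which keeps the code shape but drops device placement. For comparison, the `BatchEncoding` returned by the processor also has a `.to()` method that moves every contained tensor in one call, so a hypothetical alternative to the whole loop (assuming the module-level `device` sketched above) would be:

# hypothetical one-liner replacing the per-key loop; effectively a no-op on CPU
encoding_inputs = encoding_inputs.to(device)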
@@ -101,14 +100,6 @@ def predict(image):
     true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
     true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
 
-    return true_boxes, true_predictions
-
-def text_extraction(image):
-    feature_extractor = LayoutLMv3FeatureExtractor()
-    encoding = feature_extractor(image, return_tensors="pt")
-    return encoding['words'][0]
-
-def image_render(image):
     draw = ImageDraw.Draw(image)
     font = ImageFont.load_default()
     true_boxes,true_predictions = predict(image)
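This hunk folds the old `image_render` body into `predict`: the early `return true_boxes, true_predictions` and both intervening definitions go, so drawing now happens inline after classification (note the surviving context line `true_boxes,true_predictions = predict(image)` now sits inside `predict` itself). The drawing loop, new lines 106-112, falls between hunks and is not shown; purely for orientation, a hypothetical PIL sketch of such a loop, with no call taken from app.py:

for box, label in zip(true_boxes, true_predictions):  # hypothetical loop
    draw.rectangle(box, outline="red", width=2)       # box = [x0, y0, x1, y1] in pixels
    draw.text((box[0] + 2, box[1] - 12), label, fill="red", font=font)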
@@ -122,10 +113,15 @@ def image_render(image):
     extracted_words = convert_results(words,true_predictions)
 
     return image,extracted_words
+
+def text_extraction(image):
+    feature_extractor = LayoutLMv3FeatureExtractor()
+    encoding = feature_extractor(image, return_tensors="pt")
+    return encoding['words'][0]
 
 css = """.output_image, .input_image {height: 600px !important}"""
 
-demo = gr.Interface(fn = image_render,
+demo = gr.Interface(fn = predict,
     inputs = gr.inputs.Image(type="pil"),
     outputs = [gr.outputs.Image(type="pil", label="annotated image"),'text'],
     css = css,
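With `fn = predict`, whatever `predict` returns must match the two declared outputs, a PIL image plus text, and the hunk above shows it ends with `return image,extracted_words`. A minimal local smoke test of that contract (the file name is hypothetical):

from PIL import Image

annotated, extracted_words = predict(Image.open("sample_receipt.png").convert("RGB"))
annotated.save("annotated.png")  # image with boxes/labels drawn in place
print(extracted_words)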
@@ -136,4 +132,4 @@ demo = gr.Interface(fn = image_render,
     flagging_dir = "flagged",
     analytics_enabled = True, enable_queue=True
     )
-demo.launch(inline=False, share=False, debug=False)
+demo.launch(debug=False)
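The slimmed-down launch call leans on Gradio's defaults: `share` is off by default, and `inline` only affects rendering inside notebooks, so running the Space as a script should behave as before.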
 