Scezui commited on
Commit
c7d90d2
·
1 Parent(s): ece6d05

Fixed the inaccurate prediction

Browse files
Files changed (1) hide show
  1. app.py +25 -20
app.py CHANGED
@@ -1,3 +1,10 @@
 
 
 
 
 
 
 
1
  from flask import Flask, request, render_template, jsonify, send_file, redirect, url_for, flash, send_from_directory, session
2
  from PIL import Image, ImageDraw
3
  import torch
@@ -29,11 +36,7 @@ from Layoutlmv3_inference.ocr import prepare_batch_for_inference
29
  from Layoutlmv3_inference.inference_handler import handle
30
  import logging
31
  import os
32
- import warnings
33
-
34
- # Ignore SourceChangeWarning
35
- warnings.filterwarnings("ignore", category=DeprecationWarning)
36
- # warnings.filterwarnings("ignore", category=SourceChangeWarning)
37
 
38
 
39
  UPLOAD_FOLDER = 'static/temp/uploads'
@@ -104,8 +107,8 @@ def make_predictions(image_paths):
104
  temp = None
105
  try:
106
  # For Windows OS
107
- # temp = pathlib.PosixPath # Save the original state
108
- # pathlib.PosixPath = pathlib.WindowsPath # Change to WindowsPath temporarily
109
 
110
  model_path = Path(r'model/export')
111
  learner = load_learner(model_path)
@@ -129,10 +132,10 @@ def make_predictions(image_paths):
129
  except Exception as e:
130
  return {"error in make_predictions": str(e)}
131
 
132
- # finally:
133
- # pathlib.PosixPath = temp
 
134
 
135
- import copy
136
  @app.route('/predict/<filenames>', methods=['GET', 'POST'])
137
  def predict_files(filenames):
138
  prediction_results = []
@@ -162,17 +165,19 @@ def predict_files(filenames):
162
  if os.path.exists(file_path):
163
  # Call make_predictions automatically
164
  prediction_result = make_predictions([file_path]) # Pass file_path as a list
165
- prediction_results.extend(prediction_result)
166
-
167
- # Create a copy of prediction_results before deletion
168
  prediction_results_copy = copy.deepcopy(prediction_results)
169
-
170
- if prediction_result[0] != 'non-receipt': # Check if prediction is not 'non-receipt'
171
- prediction_results.extend(prediction_result) # Use extend to add elements of list to another list
172
- else:
173
- # Delete the image if it's predicted as non-receipt
174
- os.remove(file_path)
175
- print(image_paths)
 
 
 
 
176
 
177
  return render_template('extractor.html', image_paths=image_paths, prediction_results = prediction_results, predictions=dict(zip(image_paths, prediction_results_copy)))
178
 
 
1
+ # import warnings
2
+
3
+ # # Ignore SourceChangeWarning
4
+ # warnings.filterwarnings("ignore", category=DeprecationWarning)
5
+ # warnings.filterwarnings("ignore", category=SourceChangeWarning)
6
+
7
+ # Dependencies
8
  from flask import Flask, request, render_template, jsonify, send_file, redirect, url_for, flash, send_from_directory, session
9
  from PIL import Image, ImageDraw
10
  import torch
 
36
  from Layoutlmv3_inference.inference_handler import handle
37
  import logging
38
  import os
39
+ import copy
 
 
 
 
40
 
41
 
42
  UPLOAD_FOLDER = 'static/temp/uploads'
 
107
  temp = None
108
  try:
109
  # For Windows OS
110
+ temp = pathlib.PosixPath # Save the original state
111
+ pathlib.PosixPath = pathlib.WindowsPath # Change to WindowsPath temporarily
112
 
113
  model_path = Path(r'model/export')
114
  learner = load_learner(model_path)
 
132
  except Exception as e:
133
  return {"error in make_predictions": str(e)}
134
 
135
+ finally:
136
+ pathlib.PosixPath = temp
137
+
138
 
 
139
  @app.route('/predict/<filenames>', methods=['GET', 'POST'])
140
  def predict_files(filenames):
141
  prediction_results = []
 
165
  if os.path.exists(file_path):
166
  # Call make_predictions automatically
167
  prediction_result = make_predictions([file_path]) # Pass file_path as a list
168
+ prediction_results.append(prediction_result[0]) # Append only the first prediction result
 
 
169
  prediction_results_copy = copy.deepcopy(prediction_results)
170
+
171
+ non_receipt_indices = []
172
+ for i, prediction in enumerate(prediction_results):
173
+ if prediction == 'non-receipt':
174
+ non_receipt_indices.append(i)
175
+
176
+ # Delete images in reverse order to avoid index shifting
177
+ for index in non_receipt_indices[::-1]:
178
+ file_to_remove = os.path.join('static', 'temp', 'uploads', image_paths[index])
179
+ if os.path.exists(file_to_remove):
180
+ os.remove(file_to_remove)
181
 
182
  return render_template('extractor.html', image_paths=image_paths, prediction_results = prediction_results, predictions=dict(zip(image_paths, prediction_results_copy)))
183