Spaces:
Running
Running
Fixed the inaccurate prediction
Browse files
app.py
CHANGED
@@ -1,3 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from flask import Flask, request, render_template, jsonify, send_file, redirect, url_for, flash, send_from_directory, session
|
2 |
from PIL import Image, ImageDraw
|
3 |
import torch
|
@@ -29,11 +36,7 @@ from Layoutlmv3_inference.ocr import prepare_batch_for_inference
|
|
29 |
from Layoutlmv3_inference.inference_handler import handle
|
30 |
import logging
|
31 |
import os
|
32 |
-
import
|
33 |
-
|
34 |
-
# Ignore SourceChangeWarning
|
35 |
-
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
36 |
-
# warnings.filterwarnings("ignore", category=SourceChangeWarning)
|
37 |
|
38 |
|
39 |
UPLOAD_FOLDER = 'static/temp/uploads'
|
@@ -104,8 +107,8 @@ def make_predictions(image_paths):
|
|
104 |
temp = None
|
105 |
try:
|
106 |
# For Windows OS
|
107 |
-
|
108 |
-
|
109 |
|
110 |
model_path = Path(r'model/export')
|
111 |
learner = load_learner(model_path)
|
@@ -129,10 +132,10 @@ def make_predictions(image_paths):
|
|
129 |
except Exception as e:
|
130 |
return {"error in make_predictions": str(e)}
|
131 |
|
132 |
-
|
133 |
-
|
|
|
134 |
|
135 |
-
import copy
|
136 |
@app.route('/predict/<filenames>', methods=['GET', 'POST'])
|
137 |
def predict_files(filenames):
|
138 |
prediction_results = []
|
@@ -162,17 +165,19 @@ def predict_files(filenames):
|
|
162 |
if os.path.exists(file_path):
|
163 |
# Call make_predictions automatically
|
164 |
prediction_result = make_predictions([file_path]) # Pass file_path as a list
|
165 |
-
prediction_results.
|
166 |
-
|
167 |
-
# Create a copy of prediction_results before deletion
|
168 |
prediction_results_copy = copy.deepcopy(prediction_results)
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
176 |
|
177 |
return render_template('extractor.html', image_paths=image_paths, prediction_results = prediction_results, predictions=dict(zip(image_paths, prediction_results_copy)))
|
178 |
|
|
|
1 |
+
# import warnings
|
2 |
+
|
3 |
+
# # Ignore SourceChangeWarning
|
4 |
+
# warnings.filterwarnings("ignore", category=DeprecationWarning)
|
5 |
+
# warnings.filterwarnings("ignore", category=SourceChangeWarning)
|
6 |
+
|
7 |
+
# Dependencies
|
8 |
from flask import Flask, request, render_template, jsonify, send_file, redirect, url_for, flash, send_from_directory, session
|
9 |
from PIL import Image, ImageDraw
|
10 |
import torch
|
|
|
36 |
from Layoutlmv3_inference.inference_handler import handle
|
37 |
import logging
|
38 |
import os
|
39 |
+
import copy
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
UPLOAD_FOLDER = 'static/temp/uploads'
|
|
|
107 |
temp = None
|
108 |
try:
|
109 |
# For Windows OS
|
110 |
+
temp = pathlib.PosixPath # Save the original state
|
111 |
+
pathlib.PosixPath = pathlib.WindowsPath # Change to WindowsPath temporarily
|
112 |
|
113 |
model_path = Path(r'model/export')
|
114 |
learner = load_learner(model_path)
|
|
|
132 |
except Exception as e:
|
133 |
return {"error in make_predictions": str(e)}
|
134 |
|
135 |
+
finally:
|
136 |
+
pathlib.PosixPath = temp
|
137 |
+
|
138 |
|
|
|
139 |
@app.route('/predict/<filenames>', methods=['GET', 'POST'])
|
140 |
def predict_files(filenames):
|
141 |
prediction_results = []
|
|
|
165 |
if os.path.exists(file_path):
|
166 |
# Call make_predictions automatically
|
167 |
prediction_result = make_predictions([file_path]) # Pass file_path as a list
|
168 |
+
prediction_results.append(prediction_result[0]) # Append only the first prediction result
|
|
|
|
|
169 |
prediction_results_copy = copy.deepcopy(prediction_results)
|
170 |
+
|
171 |
+
non_receipt_indices = []
|
172 |
+
for i, prediction in enumerate(prediction_results):
|
173 |
+
if prediction == 'non-receipt':
|
174 |
+
non_receipt_indices.append(i)
|
175 |
+
|
176 |
+
# Delete images in reverse order to avoid index shifting
|
177 |
+
for index in non_receipt_indices[::-1]:
|
178 |
+
file_to_remove = os.path.join('static', 'temp', 'uploads', image_paths[index])
|
179 |
+
if os.path.exists(file_to_remove):
|
180 |
+
os.remove(file_to_remove)
|
181 |
|
182 |
return render_template('extractor.html', image_paths=image_paths, prediction_results = prediction_results, predictions=dict(zip(image_paths, prediction_results_copy)))
|
183 |
|