Scezui committed on
Commit
0e54ad9
·
1 Parent(s): 8d024e6
Files changed (1) hide show
  1. Layoutlmv3_inference/ocr.py +15 -129
Layoutlmv3_inference/ocr.py CHANGED
@@ -6,9 +6,12 @@ import json
6
  import requests
7
  import traceback
8
  import tempfile
 
 
9
 
10
  from PIL import Image
11
 
 
12
  def preprocess_image(image_path, max_file_size_mb=1, target_file_size_mb=0.5):
13
  try:
14
  # Read the image
@@ -21,10 +24,12 @@ def preprocess_image(image_path, max_file_size_mb=1, target_file_size_mb=0.5):
21
  cv2.imwrite(temp_file_path, enhanced)
22
 
23
  # Check file size of the temporary file
24
- file_size_mb = os.path.getsize(temp_file_path) / (1024 * 1024) # Convert to megabytes
 
25
 
26
  while file_size_mb > max_file_size_mb:
27
- print(f"File size ({file_size_mb} MB) exceeds the maximum allowed size ({max_file_size_mb} MB). Resizing the image.")
 
28
  ratio = np.sqrt(target_file_size_mb / file_size_mb)
29
  new_width = int(image.shape[1] * ratio)
30
  new_height = int(image.shape[0] * ratio)
@@ -63,7 +68,11 @@ def enhance_txt(img, intensity_increase=20, bilateral_filter_diameter=9, bilater
63
  # Convert image to grayscale
64
  grayscale_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
65
 
66
- # Apply Gaussian blur
 
 
 
 
67
  blurred = cv2.GaussianBlur(grayscale_img, (1, 1), 0)
68
 
69
  edged = 255 - cv2.Canny(blurred, 100, 150, apertureSize=7)
@@ -72,7 +81,8 @@ def enhance_txt(img, intensity_increase=20, bilateral_filter_diameter=9, bilater
72
  img = np.clip(img + intensity_increase, 0, 255).astype(np.uint8)
73
 
74
  # Apply bilateral filter to reduce noise
75
- img = cv2.bilateralFilter(img, bilateral_filter_diameter, bilateral_filter_sigma_color, bilateral_filter_sigma_space)
 
76
 
77
  _, binary = cv2.threshold(blurred, threshold, 255, cv2.THRESH_BINARY)
78
  return binary
@@ -91,6 +101,7 @@ def run_tesseract_on_preprocessed_image(preprocessed_image, image_path):
91
  url = "https://api.ocr.space/parse/image"
92
 
93
  # Define the API key and the language
 
94
  api_key = os.getenv("ocr_space")
95
  language = "eng"
96
 
@@ -210,128 +221,3 @@ def prepare_batch_for_inference(image_paths):
210
 
211
  print("10. Prepared for Inference Batch")
212
  return inference_batch
213
- al_filter_diameter, bilateral_filter_sigma_color, bilateral_filter_sigma_space)
214
-
215
- _, binary = cv2.threshold(blurred, threshold, 255, cv2.THRESH_BINARY)
216
- return binary
217
-
218
-
219
- def run_tesseract_on_preprocessed_image(preprocessed_image, image_path):
220
- try:
221
- image_name = os.path.basename(image_path)
222
- image_name = image_name[:image_name.find('.')]
223
-
224
- # Create the "temp" folder if it doesn't exist
225
- temp_folder = "static/temp"
226
- if not os.path.exists(temp_folder):
227
- os.makedirs(temp_folder)
228
-
229
- # Define the OCR API endpoint
230
- url = "https://api.ocr.space/parse/image"
231
-
232
- # Define the API key and the language
233
- api_key = os.getenv("ocr_space")
234
- language = "eng"
235
-
236
- # Save the preprocessed image
237
- cv2.imwrite(os.path.join(temp_folder, f"{image_name}_preprocessed.jpg"), preprocessed_image)
238
-
239
- # Open the preprocessed image file as binary
240
- with open(os.path.join(temp_folder, f"{image_name}_preprocessed.jpg"), "rb") as f:
241
- # Define the payload for the API request
242
- payload = {
243
- "apikey": api_key,
244
- "language": language,
245
- "isOverlayRequired": True,
246
- "OCREngine": 2
247
- }
248
- # Define the file parameter for the API request
249
- file = {
250
- "file": f
251
- }
252
- # Send the POST request to the OCR API
253
- response = requests.post(url, data=payload, files=file)
254
-
255
- # Check the status code of the response
256
- if response.status_code == 200:
257
- # Parse the JSON response
258
- result = response.json()
259
- print("---JSON file saved")
260
- # Save the OCR result as JSON
261
- with open(os.path.join(temp_folder, f"{image_name}_ocr.json"), 'w') as f:
262
- json.dump(result, f)
263
-
264
- return os.path.join(temp_folder, f"{image_name}_ocr.json")
265
- else:
266
- # Print the error message
267
- print("Error: " + response.text)
268
- return None
269
-
270
- except Exception as e:
271
- print(f"An error occurred during OCR request: {str(e)}")
272
- return None
273
-
274
- def clean_tesseract_output(json_output_path):
275
- try:
276
- with open(json_output_path, 'r') as json_file:
277
- data = json.load(json_file)
278
-
279
- lines = data['ParsedResults'][0]['TextOverlay']['Lines']
280
-
281
- words = []
282
- for line in lines:
283
- for word_info in line['Words']:
284
- word = {}
285
- origin_box = [
286
- word_info['Left'],
287
- word_info['Top'],
288
- word_info['Left'] + word_info['Width'],
289
- word_info['Top'] + word_info['Height']
290
- ]
291
-
292
- word['word_text'] = word_info['WordText']
293
- word['word_box'] = origin_box
294
- words.append(word)
295
-
296
- return words
297
- except (KeyError, IndexError, FileNotFoundError, json.JSONDecodeError) as e:
298
- print(f"Error cleaning Tesseract output: {str(e)}")
299
- return None
300
-
301
- def prepare_batch_for_inference(image_paths):
302
- # print("my_function was called")
303
- # traceback.print_stack() # This will print the stack trace
304
- print(f"Number of images to process: {len(image_paths)}") # Print the total number of images to be processed
305
- print("1. Preparing for Inference")
306
- tsv_output_paths = []
307
-
308
- inference_batch = dict()
309
- print("2. Starting Preprocessing")
310
- # Ensure that the image is only 1
311
- for image_path in image_paths:
312
- print(f"Processing the image: {image_path}") # Print the image being processed
313
- print("3. Preprocessing the Receipt")
314
- preprocessed_image = preprocess_image(image_path)
315
- if preprocessed_image is not None:
316
- print("4. Preprocessing done. Running OCR")
317
- json_output_path = run_tesseract_on_preprocessed_image(preprocessed_image, image_path)
318
- print("5. OCR Complete")
319
- if json_output_path:
320
- tsv_output_paths.append(json_output_path)
321
-
322
- print("6. Preprocessing and OCR Done")
323
- # clean_outputs is a list of lists
324
- clean_outputs = [clean_tesseract_output(tsv_path) for tsv_path in tsv_output_paths]
325
- print("7. Cleaned OCR output")
326
- word_lists = [[word['word_text'] for word in clean_output] for clean_output in clean_outputs]
327
- print("8. Word List Created")
328
- boxes_lists = [[word['word_box'] for word in clean_output] for clean_output in clean_outputs]
329
- print("9. Box List Created")
330
- inference_batch = {
331
- "image_path": image_paths,
332
- "bboxes": boxes_lists,
333
- "words": word_lists
334
- }
335
-
336
- print("10. Prepared for Inference Batch")
337
- return inference_batch
 
6
  import requests
7
  import traceback
8
  import tempfile
9
+ from rembg import remove
10
+
11
 
12
  from PIL import Image
13
 
14
+
15
  def preprocess_image(image_path, max_file_size_mb=1, target_file_size_mb=0.5):
16
  try:
17
  # Read the image
 
24
  cv2.imwrite(temp_file_path, enhanced)
25
 
26
  # Check file size of the temporary file
27
+ file_size_mb = os.path.getsize(
28
+ temp_file_path) / (1024 * 1024) # Convert to megabytes
29
 
30
  while file_size_mb > max_file_size_mb:
31
+ print(
32
+ f"File size ({file_size_mb} MB) exceeds the maximum allowed size ({max_file_size_mb} MB). Resizing the image.")
33
  ratio = np.sqrt(target_file_size_mb / file_size_mb)
34
  new_width = int(image.shape[1] * ratio)
35
  new_height = int(image.shape[0] * ratio)
 
68
  # Convert image to grayscale
69
  grayscale_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
70
 
71
+ # Find contours
72
+ contours, _ = cv2.findContours(
73
+ grayscale_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
74
+
75
+ # # Apply Gaussian blur
76
  blurred = cv2.GaussianBlur(grayscale_img, (1, 1), 0)
77
 
78
  edged = 255 - cv2.Canny(blurred, 100, 150, apertureSize=7)
 
81
  img = np.clip(img + intensity_increase, 0, 255).astype(np.uint8)
82
 
83
  # Apply bilateral filter to reduce noise
84
+ img = cv2.bilateralFilter(img, bilateral_filter_diameter,
85
+ bilateral_filter_sigma_color, bilateral_filter_sigma_space)
86
 
87
  _, binary = cv2.threshold(blurred, threshold, 255, cv2.THRESH_BINARY)
88
  return binary
 
101
  url = "https://api.ocr.space/parse/image"
102
 
103
  # Define the API key and the language
104
+ # api_key = "K88232854988957" # Replace with your actual OCR Space API key
105
  api_key = os.getenv("ocr_space")
106
  language = "eng"
107
 
 
221
 
222
  print("10. Prepared for Inference Batch")
223
  return inference_batch