mike23415 committed · Commit c72429b · verified · 1 Parent(s): 363cac0

Update app.py

Files changed (1):
  1. app.py +14 -22
app.py CHANGED
@@ -4,7 +4,7 @@ from flask import Flask, request, jsonify
 from flask_cors import CORS
 from werkzeug.utils import secure_filename
 import torch
-from transformers import DonutProcessor, VisionEncoderDecoderModel
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 from PIL import Image
 import cv2
 import numpy as np
@@ -21,13 +21,14 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max upload
 # Create uploads directory if it doesn't exist
 os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
-# Load OCR model - Microsoft's Donut model
+# Load OCR model - Microsoft's TrOCR model
 print("Loading OCR model...")
-processor = DonutProcessor.from_pretrained("microsoft/donut-base", cache_dir="/huggingface_cache")
-model = VisionEncoderDecoderModel.from_pretrained("microsoft/donut-base", cache_dir="/huggingface_cache")
+processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten", cache_dir="/huggingface_cache")
+model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten", cache_dir="/huggingface_cache")
 
 # Move model to GPU if available
 device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
 model.to(device)
 
 def allowed_file(filename):
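For context, the new loading path can be sanity-checked outside Flask with a minimal sketch like the one below. It uses the default Hugging Face cache rather than the app's /huggingface_cache directory, which this commit assumes exists and is writable.

# Minimal sketch (not part of this commit): load the same TrOCR checkpoint
# and confirm which device it lands on, using the default HF cache instead
# of the app's /huggingface_cache path.
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
n_params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Loaded TrOCR on {device} (~{n_params:.0f}M parameters)")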
@@ -60,27 +61,16 @@ def perform_ocr(image_path):
 pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
 
 # Generate text
-task_prompt = "<s_ocr>"
-decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
-
-outputs = model.generate(
+generated_ids = model.generate(
 pixel_values,
-decoder_input_ids=decoder_input_ids,
-max_length=model.decoder.config.max_position_embeddings,
-early_stopping=True,
-pad_token_id=processor.tokenizer.pad_token_id,
-eos_token_id=processor.tokenizer.eos_token_id,
-use_cache=True,
+max_length=64,
 num_beams=5,
-bad_words_ids=[[processor.tokenizer.unk_token_id]],
-return_dict_in_generate=True,
+early_stopping=True
 )
 
 # Decode generated text
-sequence = processor.batch_decode(outputs.sequences)[0]
-sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
-sequence = sequence.replace("<s>", "").replace("</s>", "").replace("<s_ocr>", "").replace("</s_ocr>", "")
-return sequence.strip()
+generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+return generated_text.strip()
 
 @app.route('/ocr', methods=['POST'])
 def ocr():
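The Donut-specific task prompt, decoder_input_ids, and manual token stripping are gone; generation is now a single beam-search call and decoding relies on skip_special_tokens=True. A self-contained sketch of the new path, assuming a local test image at sample.png (a placeholder name) and a plain RGB conversion, since the app's own image preprocessing is not visible in this hunk:

# Standalone sketch of the new TrOCR generate/decode path used by perform_ocr.
# "sample.png" is a placeholder test image, not a file from this repo.
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten").to(device)

image = Image.open("sample.png").convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

# Beam search with the same settings as this commit: 5 beams, 64-token cap.
generated_ids = model.generate(
    pixel_values,
    max_length=64,
    num_beams=5,
    early_stopping=True,
)

# skip_special_tokens=True makes the old manual <s>/<s_ocr> cleanup unnecessary.
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(text.strip())

Worth noting that the handwritten TrOCR checkpoint is trained on single text lines, so multi-line scans generally need to be segmented into lines before this call; the 64-token cap fits that line-level use.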
@@ -107,14 +97,16 @@ def ocr():
 # Perform OCR
 extracted_text = perform_ocr(file_path)
 
-# Clean up the file if needed
-# os.remove(file_path)
+# Clean up the file after processing
+os.remove(file_path)
 
 return jsonify({
 'success': True,
 'text': extracted_text
 })
 except Exception as e:
+# Log the error
+print(f"Error processing image: {str(e)}")
 return jsonify({
 'success': False,
 'error': str(e)
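With the temporary upload now deleted after OCR and errors logged server-side, a quick way to exercise the endpoint is a small client script. This is a hypothetical example: the host/port and the name of the upload form field are not visible in this diff, so both are assumptions. Note also that os.remove(file_path) runs only after perform_ocr succeeds, so a failed request still leaves the uploaded file in the uploads directory.

# Hypothetical client for the /ocr endpoint. Assumes the Flask app listens on
# localhost:5000 and expects the upload under the form field "file"; adjust
# both to match the actual deployment.
import requests

with open("sample.png", "rb") as f:
    resp = requests.post("http://localhost:5000/ocr", files={"file": f})

data = resp.json()
if data.get("success"):
    print("Extracted text:", data["text"])
else:
    print("OCR failed:", data.get("error"))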
 