mashaelalbu commited on
Commit
86ffb63
·
verified ·
1 Parent(s): e7d0dbe

update to gpu

Browse files
Files changed (1) hide show
  1. app/utils.py +104 -104
app/utils.py CHANGED
@@ -1,105 +1,105 @@
1
- # utils.py
2
- import os
3
- from transformers import AutoModel, AutoTokenizer
4
- from PIL import Image, ImageEnhance, ImageFilter
5
- import torch
6
- import logging
7
-
8
- logger = logging.getLogger(__name__)
9
-
10
- class OCRModel:
11
- _instance = None
12
-
13
- def __new__(cls):
14
- if cls._instance is None:
15
- cls._instance = super(OCRModel, cls).__new__(cls)
16
- cls._instance.initialize()
17
- return cls._instance
18
-
19
- def initialize(self):
20
- try:
21
- logger.info("Initializing OCR model...")
22
-
23
- # Model initilization
24
- self.tokenizer = AutoTokenizer.from_pretrained('RufusRubin777/GOT-OCR2_0_CPU', trust_remote_code=True)
25
- self.model = AutoModel.from_pretrained(
26
- 'RufusRubin777/GOT-OCR2_0_CPU',
27
- trust_remote_code=True,
28
- low_cpu_mem_usage=True,
29
- device_map='cpu',
30
- use_safetensors=True,
31
- pad_token_id=self.tokenizer.eos_token_id
32
- )
33
-
34
- # choose cpu
35
- self.device = "cpu"
36
- self.model = self.model.eval().cpu()
37
-
38
- logger.info("Model initialization completed successfully")
39
-
40
- except Exception as e:
41
- logger.error(f"Error initializing model: {str(e)}", exc_info=True)
42
- raise
43
-
44
- def preprocess_image(self, image):
45
- """Image preprocessing to improve text recognition quality"""
46
- try:
47
- # Convert image to RGB if it is not already
48
- if image.mode != 'RGB':
49
- image = image.convert('RGB')
50
-
51
- # Improve contrast
52
- enhancer = ImageEnhance.Contrast(image)
53
- image = enhancer.enhance(1.5)
54
-
55
- # Improve Sharpness
56
- enhancer = ImageEnhance.Sharpness(image)
57
- image = enhancer.enhance(1.5)
58
-
59
- # Improve Brightness
60
- enhancer = ImageEnhance.Brightness(image)
61
- image = enhancer.enhance(1.2)
62
-
63
- # Apply a filter to soften the image a little.
64
- image = image.filter(ImageFilter.SMOOTH)
65
-
66
- return image
67
- except Exception as e:
68
- logger.error(f"Error in image preprocessing: {str(e)}", exc_info=True)
69
- raise
70
-
71
- def process_image(self, image_stream):
72
- try:
73
- logger.info("Starting image processing")
74
-
75
- # Save image temporarily because the model requires a file path.
76
- temp_image_path = "temp_image.jpg"
77
-
78
- # Reset the start pointer for BytesIO
79
- image_stream.seek(0)
80
-
81
- # Open and save the image temporarily.
82
- image = Image.open(image_stream).convert('RGB')
83
- processed_image = self.preprocess_image(image)
84
- processed_image.save(temp_image_path)
85
-
86
- # ocr
87
- try:
88
- result = self.model.chat(self.tokenizer, temp_image_path, ocr_type='format')
89
- logger.info(f"Successfully extracted text: {result[:100]}...")
90
-
91
- # Delete temporary file
92
- if os.path.exists(temp_image_path):
93
- os.remove(temp_image_path)
94
-
95
- return result.strip()
96
-
97
- except Exception as e:
98
- logger.error(f"Error in OCR processing: {str(e)}", exc_info=True)
99
- if os.path.exists(temp_image_path):
100
- os.remove(temp_image_path)
101
- raise
102
-
103
- except Exception as e:
104
- logger.error(f"Error in image processing: {str(e)}", exc_info=True)
105
  return f"Error processing image: {str(e)}"
 
1
+ # utils.py
2
+ import os
3
+ from transformers import AutoModel, AutoTokenizer
4
+ from PIL import Image, ImageEnhance, ImageFilter
5
+ import torch
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class OCRModel:
11
+ _instance = None
12
+
13
+ def __new__(cls):
14
+ if cls._instance is None:
15
+ cls._instance = super(OCRModel, cls).__new__(cls)
16
+ cls._instance.initialize()
17
+ return cls._instance
18
+
19
+ def initialize(self):
20
+ try:
21
+ logger.info("Initializing OCR model...")
22
+
23
+ # Model initilization
24
+ self.tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR-2.0-hf', trust_remote_code=True)
25
+ self.model = AutoModel.from_pretrained(
26
+ 'stepfun-ai/GOT-OCR-2.0-hf',
27
+ trust_remote_code=True,
28
+ low_cpu_mem_usage=True,
29
+ device_map='cuda',
30
+ use_safetensors=True,
31
+ pad_token_id=self.tokenizer.eos_token_id
32
+ )
33
+
34
+ # choose cpu
35
+ self.device = "cuda"
36
+ self.model = self.model.eval().cpu()
37
+
38
+ logger.info("Model initialization completed successfully")
39
+
40
+ except Exception as e:
41
+ logger.error(f"Error initializing model: {str(e)}", exc_info=True)
42
+ raise
43
+
44
+ def preprocess_image(self, image):
45
+ """Image preprocessing to improve text recognition quality"""
46
+ try:
47
+ # Convert image to RGB if it is not already
48
+ if image.mode != 'RGB':
49
+ image = image.convert('RGB')
50
+
51
+ # Improve contrast
52
+ enhancer = ImageEnhance.Contrast(image)
53
+ image = enhancer.enhance(1.5)
54
+
55
+ # Improve Sharpness
56
+ enhancer = ImageEnhance.Sharpness(image)
57
+ image = enhancer.enhance(1.5)
58
+
59
+ # Improve Brightness
60
+ enhancer = ImageEnhance.Brightness(image)
61
+ image = enhancer.enhance(1.2)
62
+
63
+ # Apply a filter to soften the image a little.
64
+ image = image.filter(ImageFilter.SMOOTH)
65
+
66
+ return image
67
+ except Exception as e:
68
+ logger.error(f"Error in image preprocessing: {str(e)}", exc_info=True)
69
+ raise
70
+
71
+ def process_image(self, image_stream):
72
+ try:
73
+ logger.info("Starting image processing")
74
+
75
+ # Save image temporarily because the model requires a file path.
76
+ temp_image_path = "temp_image.jpg"
77
+
78
+ # Reset the start pointer for BytesIO
79
+ image_stream.seek(0)
80
+
81
+ # Open and save the image temporarily.
82
+ image = Image.open(image_stream).convert('RGB')
83
+ processed_image = self.preprocess_image(image)
84
+ processed_image.save(temp_image_path)
85
+
86
+ # ocr
87
+ try:
88
+ result = self.model.chat(self.tokenizer, temp_image_path, ocr_type='format')
89
+ logger.info(f"Successfully extracted text: {result[:100]}...")
90
+
91
+ # Delete temporary file
92
+ if os.path.exists(temp_image_path):
93
+ os.remove(temp_image_path)
94
+
95
+ return result.strip()
96
+
97
+ except Exception as e:
98
+ logger.error(f"Error in OCR processing: {str(e)}", exc_info=True)
99
+ if os.path.exists(temp_image_path):
100
+ os.remove(temp_image_path)
101
+ raise
102
+
103
+ except Exception as e:
104
+ logger.error(f"Error in image processing: {str(e)}", exc_info=True)
105
  return f"Error processing image: {str(e)}"