Update app/utils.py
app/utils.py  CHANGED  (+129 -36)
@@ -1,10 +1,13 @@
-# utils.py
 import os
+import pandas as pd
 from transformers import AutoModel, AutoTokenizer
 from PIL import Image, ImageEnhance, ImageFilter
 import torch
 import logging
 from transformers import BertTokenizer
+import nltk
+import requests
+import io

 logger = logging.getLogger(__name__)

@@ -21,9 +24,8 @@ class OCRModel:
         try:
             logger.info("Initializing OCR model...")

-            #
+            # Attempt to load the model
             try:
-                # First try with the standard approach
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     'stepfun-ai/GOT-OCR2_0',
                     trust_remote_code=True,
@@ -31,7 +33,6 @@ class OCRModel:
                 )
             except Exception as e:
                 logger.warning(f"Standard tokenizer failed, trying BertTokenizer: {str(e)}")
-                # Fall back to BertTokenizer if AutoTokenizer fails
                 self.tokenizer = BertTokenizer.from_pretrained(
                     'stepfun-ai/GOT-OCR2_0',
                     trust_remote_code=True
@@ -55,25 +56,24 @@ class OCRModel:
             raise

     def preprocess_image(self, image):
-        """
+        """Enhance image quality to improve text extraction"""
         try:
-            # Convert image to RGB if it is not already
             if image.mode != 'RGB':
                 image = image.convert('RGB')

-            #
+            # Enhance contrast
             enhancer = ImageEnhance.Contrast(image)
             image = enhancer.enhance(1.5)

-            #
+            # Enhance sharpness
             enhancer = ImageEnhance.Sharpness(image)
             image = enhancer.enhance(1.5)

-            #
+            # Enhance brightness
             enhancer = ImageEnhance.Brightness(image)
             image = enhancer.enhance(1.2)

-            #
+            # Apply a smoothing filter
             image = image.filter(ImageFilter.SMOOTH)

             return image
@@ -81,38 +81,131 @@ class OCRModel:
             logger.error(f"Error in image preprocessing: {str(e)}", exc_info=True)
             raise

-    def process_image(self, image_stream):
+    def process_image(self, image):
         try:
             logger.info("Starting image processing")

-            temp_image_path = "temp_image.jpg"
-
-            # Reset the start pointer for BytesIO
-            image_stream.seek(0)
-
-            # Open and save the image temporarily.
-            image = Image.open(image_stream).convert('RGB')
+            # Process the image
             processed_image = self.preprocess_image(image)
+
+            # Save the image temporarily for the model
+            temp_image_path = "temp_ocr_image.jpg"
             processed_image.save(temp_image_path)

+            # Extract the text
+            result = self.model.chat(self.tokenizer, temp_image_path, ocr_type='format')
+            logger.info(f"Successfully extracted text: {result[:100]}...")
+
+            # Delete the temporary file
+            if os.path.exists(temp_image_path):
+                os.remove(temp_image_path)
+
+            return result.strip()
+
         except Exception as e:
-            logger.error(f"Error
-            return
+            logger.error(f"Error in OCR processing: {str(e)}", exc_info=True)
+            if os.path.exists(temp_image_path):
+                os.remove(temp_image_path)
+            return f"Error processing image: {str(e)}"
+
+class AllergyAnalyzer:
+    def __init__(self, dataset_path):
+        self.dataset_path = dataset_path
+        self.allergy_dict = self.load_allergy_data()
+        nltk.download('punkt', quiet=True)
+
+    def load_allergy_data(self):
+        """Load allergy data from an Excel file"""
+        try:
+            df = pd.read_excel(self.dataset_path)
+            allergy_dict = {}
+
+            for index, row in df.iterrows():
+                allergy = row['Allergy']
+                ingredients = [ingredient for ingredient in row[1:] if pd.notna(ingredient)]
+                allergy_dict[allergy] = ingredients
+
+            return allergy_dict
+        except Exception as e:
+            logger.error(f"Error loading allergy data: {str(e)}", exc_info=True)
+            return {}
+
+    def tokenize_text(self, text):
+        """Split the text into words"""
+        tokens = nltk.word_tokenize(text)
+        return [w.lower() for w in tokens if w.isalpha()]
+
+    def check_database_allergens(self, token, user_allergens):
+        """Check the database for allergen matches"""
+        results = []
+        for allergy in user_allergens:
+            if allergy in self.allergy_dict and token in self.allergy_dict[allergy]:
+                results.append(allergy)
+        return results
+
+    def check_claude_allergens(self, token, allergy, api_key):
+        """Query the Claude API about allergens"""
+        prompt = f"""
+        You are a professional food safety expert. Analyze if '{token}' contains or is derived from {allergy}.
+
+        Respond ONLY with 'Yes' or 'No'. No explanations.
+        """
+
+        url = "https://api.anthropic.com/v1/messages"
+        headers = {
+            "x-api-key": api_key,
+            "content-type": "application/json",
+            "anthropic-version": "2023-06-01"
+        }
+
+        data = {
+            "model": "claude-3-opus-20240229",
+            "messages": [{"role": "user", "content": prompt}],
+            "max_tokens": 10
+        }
+
+        try:
+            response = requests.post(url, json=data, headers=headers)
+            json_response = response.json()
+
+            if "content" in json_response and isinstance(json_response["content"], list):
+                return json_response["content"][0]["text"].strip().lower() == 'yes'
+            return False
+        except Exception as e:
+            logger.error(f"Error querying Claude API: {str(e)}")
+            return False
+
+    def analyze_text(self, text, user_allergens, claude_api_key=None):
+        """Analyze the text to detect allergens"""
+        detected_allergens = set()
+        database_matches = {}
+        claude_matches = {}
+        tokens = self.tokenize_text(text)
+
+        for token in tokens:
+            # Check the database first
+            db_results = self.check_database_allergens(token, user_allergens)
+
+            if db_results:
+                for allergy in db_results:
+                    detected_allergens.add(allergy)
+                    database_matches[allergy] = database_matches.get(allergy, []) + [token]
+            else:
+                # If not found in the database, fall back to the Claude API
+                if claude_api_key:
+                    for allergy in user_allergens:
+                        if self.check_claude_allergens(token, allergy, claude_api_key):
+                            detected_allergens.add(allergy)
+                            claude_matches[allergy] = claude_matches.get(allergy, []) + [token]
+
+        return {
+            "detected_allergens": list(detected_allergens),
+            "database_matches": database_matches,
+            "claude_matches": claude_matches,
+            "analyzed_tokens": tokens
+        }
+
+    def get_allergen_list(self):
+        """Get the list of known allergens"""
+        return list(self.allergy_dict.keys())
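For context, a minimal usage sketch of the API this commit introduces. This is an assumption-laden sketch, not part of the commit: it assumes OCRModel() takes no constructor arguments (only fragments of its __init__ are visible in this diff), and the file names and allergen list below are hypothetical.

# Minimal usage sketch -- assumptions: OCRModel() takes no arguments;
# "allergies.xlsx" (an 'Allergy' column followed by ingredient columns,
# as load_allergy_data expects) and "label.jpg" are hypothetical inputs.
from PIL import Image

from app.utils import AllergyAnalyzer, OCRModel

ocr = OCRModel()                              # loads stepfun-ai/GOT-OCR2_0
analyzer = AllergyAnalyzer("allergies.xlsx")

# process_image() now takes a PIL image directly; the old version read
# from a BytesIO stream and opened the image itself.
text = ocr.process_image(Image.open("label.jpg"))

# Database lookup runs first; tokens with no database match are sent to
# the Claude API only when an API key is supplied.
report = analyzer.analyze_text(text, ["milk", "peanut"], claude_api_key=None)
print(report["detected_allergens"])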