ARCQUB committed · verified
Commit e29cf96 · Parent(s): 8093104

Update models/gpt4o.py

Change the default prompt_file path in load_prompt() from an absolute Colab path to the repository-relative prompts/prompt.txt.

Files changed (1): models/gpt4o.py (+111, -111)
models/gpt4o.py CHANGED
@@ -1,111 +1,111 @@
- # gpt4o_pix2struct_ocr.py
-
- import os
- import json
- import base64
- from PIL import Image
- from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
- import numpy as np
-
- import openai
-
- model = "gpt-4o"
-
- # Load Pix2Struct model + processor (vision-language OCR)
- processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
- pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
-
-
- def load_prompt(prompt_file="/content/vision_model_space/vision_model_space_new/prompts/prompt.txt"):
-     with open(prompt_file, "r", encoding="utf-8") as f:
-         return f.read().strip()
-
-
- def try_extract_json(text):
-     try:
-         return json.loads(text)
-     except json.JSONDecodeError:
-         start = text.find('{')
-         if start == -1:
-             return None
-         brace_count = 0
-         json_candidate = ''
-         for i in range(start, len(text)):
-             if text[i] == '{':
-                 brace_count += 1
-             elif text[i] == '}':
-                 brace_count -= 1
-             json_candidate += text[i]
-             if brace_count == 0 and json_candidate.strip():
-                 break
-         try:
-             return json.loads(json_candidate)
-         except json.JSONDecodeError:
-             return None
-
-
- def encode_image_base64(image: Image.Image):
-     from io import BytesIO
-     buffer = BytesIO()
-     image.save(buffer, format="JPEG")
-     return base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-
- def extract_all_text_pix2struct(image: Image.Image):
-     inputs = processor(images=image, return_tensors="pt")
-     predictions = pix2struct_model.generate(**inputs, max_new_tokens=512)
-     output_text = processor.decode(predictions[0], skip_special_tokens=True)
-     return output_text.strip()
-
-
- # Optional: assign best-matching label from full extracted text using proximity (simplified version)
- def assign_event_gateway_names_from_ocr(image: Image.Image, json_data, ocr_text):
-     if not ocr_text:
-         return json_data
-
-     # You could use NLP matching or regex in complex cases
-     words = ocr_text.split()
-
-     def guess_name_fallback(obj):
-         if not obj.get("name") or obj["name"].strip() == "":
-             obj["name"] = "(label unknown)"  # fallback if matching logic isn't yet implemented
-
-     for evt in json_data.get("events", []):
-         guess_name_fallback(evt)
-
-     for gw in json_data.get("gateways", []):
-         guess_name_fallback(gw)
-
-     return json_data
-
-
- def run_model(image: Image.Image, api_key: str = None):
-     prompt_text = load_prompt()
-     encoded_image = encode_image_base64(image)
-
-     if not api_key:
-         return {"json": None, "raw": "⚠️ API key is missing. Please provide your OpenAI API key."}
-
-     client = openai.OpenAI(api_key=api_key)
-     response = client.chat.completions.create(
-         model=model,
-         messages=[
-             {
-                 "role": "user",
-                 "content": [
-                     {"type": "text", "text": prompt_text},
-                     {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
-                 ]
-             }
-         ],
-         max_tokens=5000
-     )
-
-     output_text = response.choices[0].message.content.strip()
-     parsed_json = try_extract_json(output_text)
-
-     # Vision-language OCR assist step (Pix2Struct)
-     full_ocr_text = extract_all_text_pix2struct(image)
-     parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)
-
-     return {"json": parsed_json, "raw": output_text}
+ # gpt4o_pix2struct_ocr.py
+
+ import os
+ import json
+ import base64
+ from PIL import Image
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+ import numpy as np
+
+ import openai
+
+ model = "gpt-4o"
+
+ # Load Pix2Struct model + processor (vision-language OCR)
+ processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
+ pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+
+
+ def load_prompt(prompt_file="prompts/prompt.txt"):
+     with open(prompt_file, "r", encoding="utf-8") as f:
+         return f.read().strip()
+
+
+ def try_extract_json(text):
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError:
+         start = text.find('{')
+         if start == -1:
+             return None
+         brace_count = 0
+         json_candidate = ''
+         for i in range(start, len(text)):
+             if text[i] == '{':
+                 brace_count += 1
+             elif text[i] == '}':
+                 brace_count -= 1
+             json_candidate += text[i]
+             if brace_count == 0 and json_candidate.strip():
+                 break
+         try:
+             return json.loads(json_candidate)
+         except json.JSONDecodeError:
+             return None
+
+
+ def encode_image_base64(image: Image.Image):
+     from io import BytesIO
+     buffer = BytesIO()
+     image.save(buffer, format="JPEG")
+     return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+
+ def extract_all_text_pix2struct(image: Image.Image):
+     inputs = processor(images=image, return_tensors="pt")
+     predictions = pix2struct_model.generate(**inputs, max_new_tokens=512)
+     output_text = processor.decode(predictions[0], skip_special_tokens=True)
+     return output_text.strip()
+
+
+ # Optional: assign best-matching label from full extracted text using proximity (simplified version)
+ def assign_event_gateway_names_from_ocr(image: Image.Image, json_data, ocr_text):
+     if not ocr_text:
+         return json_data
+
+     # You could use NLP matching or regex in complex cases
+     words = ocr_text.split()
+
+     def guess_name_fallback(obj):
+         if not obj.get("name") or obj["name"].strip() == "":
+             obj["name"] = "(label unknown)"  # fallback if matching logic isn't yet implemented
+
+     for evt in json_data.get("events", []):
+         guess_name_fallback(evt)
+
+     for gw in json_data.get("gateways", []):
+         guess_name_fallback(gw)
+
+     return json_data
+
+
+ def run_model(image: Image.Image, api_key: str = None):
+     prompt_text = load_prompt()
+     encoded_image = encode_image_base64(image)
+
+     if not api_key:
+         return {"json": None, "raw": "⚠️ API key is missing. Please provide your OpenAI API key."}
+
+     client = openai.OpenAI(api_key=api_key)
+     response = client.chat.completions.create(
+         model=model,
+         messages=[
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": prompt_text},
+                     {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
+                 ]
+             }
+         ],
+         max_tokens=5000
+     )
+
+     output_text = response.choices[0].message.content.strip()
+     parsed_json = try_extract_json(output_text)
+
+     # Vision-language OCR assist step (Pix2Struct)
+     full_ocr_text = extract_all_text_pix2struct(image)
+     parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)
+
+     return {"json": parsed_json, "raw": output_text}
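Note on the updated file: try_extract_json returns None when no JSON object can be parsed out of the model reply, and run_model then hands that None to assign_event_gateway_names_from_ocr, where json_data.get("events", []) raises AttributeError. A minimal guard a follow-up commit could add (a sketch against the code above, not part of this diff):

def assign_event_gateway_names_from_ocr(image: Image.Image, json_data, ocr_text):
    # Also guard json_data: try_extract_json returns None when the
    # model reply contains no parseable JSON object.
    if not json_data or not ocr_text:
        return json_data

    def guess_name_fallback(obj):
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        guess_name_fallback(evt)
    for gw in json_data.get("gateways", []):
        guess_name_fallback(gw)

    return json_data

Relatedly, encode_image_base64 saves the image as JPEG, which raises OSError for RGBA or paletted inputs, so callers should convert first. A hedged usage sketch (the import path assumes the repo root is on sys.path; the filename and key are placeholders):

from PIL import Image
from models.gpt4o import run_model

image = Image.open("diagram.png").convert("RGB")  # JPEG encoding requires RGB
result = run_model(image, api_key="sk-...")
print(result["json"] if result["json"] is not None else result["raw"])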