# gpt4o_pix2struct_ocr.py
import os
import json
import base64
from io import BytesIO

from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import openai

model = "gpt-4o"

# Load Pix2Struct model + processor (vision-language OCR)
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

def load_prompt(prompt_file="prompts/prompt.txt"):
    with open(prompt_file, "r", encoding="utf-8") as f:
        return f.read().strip()

def try_extract_json(text):
    """Parse `text` as JSON; on failure, extract the first balanced {...} block.

    Note: the brace counter does not skip braces inside string literals, so a
    value like {"name": "a}b"} can defeat the extraction.
    """
    if not text:
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find('{')
        if start == -1:
            return None
        brace_count = 0
        json_candidate = ''
        for i in range(start, len(text)):
            if text[i] == '{':
                brace_count += 1
            elif text[i] == '}':
                brace_count -= 1
            json_candidate += text[i]
            if brace_count == 0 and json_candidate.strip():
                break
        try:
            return json.loads(json_candidate)
        except json.JSONDecodeError:
            return None

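# A quick illustration (not executed at import time): try_extract_json should
# recover the object even when the model wraps it in prose or a code fence.
#
#   try_extract_json('Sure! ```json\n{"events": []}\n``` Hope that helps.')
#   -> {'events': []}
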
def encode_image_base64(image: Image.Image):
    # JPEG has no alpha channel, so convert e.g. RGBA/P images before saving.
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def extract_all_text_pix2struct(image: Image.Image):
    inputs = processor(images=image, return_tensors="pt")
    predictions = pix2struct_model.generate(**inputs, max_new_tokens=512)
    output_text = processor.decode(predictions[0], skip_special_tokens=True)
    return output_text.strip()

# Optional: assign a best-matching label from the extracted text using
# proximity (simplified version: unnamed items just get a stub label).
def assign_event_gateway_names_from_ocr(image: Image.Image, json_data, ocr_text):
    if not json_data or not ocr_text:
        return json_data

    # NLP matching or regex against the OCR tokens could be added here for
    # complex cases; until then, fall back to a placeholder name.
    def guess_name_fallback(obj):
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        guess_name_fallback(evt)
    for gw in json_data.get("gateways", []):
        guess_name_fallback(gw)
    return json_data

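# Illustrative shape of `json_data` expected above (the real schema is defined
# by prompts/prompt.txt, so treat this as an assumption, not a contract):
#
#   {
#     "events":   [{"id": "e1", "name": ""}],
#     "gateways": [{"id": "g1", "name": "Approved?"}]
#   }
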
def run_model(image: Image.Image, api_key: str = None):
    # Fail fast before loading the prompt or encoding the image.
    if not api_key:
        return {"json": None, "raw": "⚠️ API key is missing. Please provide your OpenAI API key."}

    prompt_text = load_prompt()
    encoded_image = encode_image_base64(image)

    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
                ]
            }
        ],
        max_tokens=5000
    )
    output_text = response.choices[0].message.content.strip()
    parsed_json = try_extract_json(output_text)

    # Vision-language OCR assist step (Pix2Struct): only worth running when
    # the model's answer actually parsed into JSON.
    if parsed_json is not None:
        full_ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)

    return {"json": parsed_json, "raw": output_text}