Spaces:
Running
Running
# gpt4o_pix2struct_ocr.py | |
import os | |
import json | |
import base64 | |
from PIL import Image | |
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor | |
import numpy as np | |
import openai | |
# OpenAI chat model used for the structured-extraction request.
model = "gpt-4o"

# Pix2Struct checkpoint for the vision-language OCR assist step. The processor
# and model are both created once, at import time, and reused by
# extract_all_text_pix2struct().
_PIX2STRUCT_CHECKPOINT = "google/pix2struct-textcaps-base"
processor = Pix2StructProcessor.from_pretrained(_PIX2STRUCT_CHECKPOINT)
pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(_PIX2STRUCT_CHECKPOINT)
def load_prompt(prompt_file="prompts/prompt.txt"):
    """Return the text of *prompt_file* with leading/trailing whitespace removed.

    The file is decoded as UTF-8.
    """
    with open(prompt_file, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents.strip()
def try_extract_json(text):
    """Parse JSON from *text*, tolerating prose or code fences around an object.

    The whole string is first parsed as JSON, so any valid JSON document is
    returned unchanged. If that fails, each ``{`` in the string is tried in
    turn as the start of a JSON object. The first one that decodes is
    returned, and anything after it is ignored.

    Args:
        text: Raw model output. ``None`` is treated as "nothing to parse".

    Returns:
        The decoded JSON value, or ``None`` if no JSON could be found.
    """
    if text is None:
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # raw_decode follows the real JSON grammar. Braces inside string values
    # (e.g. {"name": "a}b"}) therefore do not end the object early. It also
    # stops at the end of the first complete value, so trailing prose or a
    # closing code fence does not break parsing.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next candidate.
            start = text.find("{", start + 1)
    return None
def _to_jpeg_compatible(image):
    """Return *image* in a mode Pillow can write as JPEG ("RGB" or "L")."""
    if image.mode in ("RGB", "L"):
        return image
    has_alpha = image.mode in ("RGBA", "LA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if has_alpha:
        rgba = image.convert("RGBA")
        # Flatten onto white. A plain convert("RGB") would drop alpha and turn
        # transparent backgrounds black, hiding dark diagram lines.
        background = Image.new("RGB", rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.getchannel("A"))
        return background
    return image.convert("RGB")


def encode_image_base64(image: Image.Image):
    """Encode *image* as a base64 JPEG string suitable for a ``data:`` URL.

    Pillow cannot write alpha or palette modes (RGBA, LA, P) as JPEG, so such
    images are first converted to RGB, with transparency composited onto white.

    Args:
        image: The PIL image to encode. The caller's object is not modified.

    Returns:
        The base64-encoded JPEG bytes, decoded to a ``str``.
    """
    from io import BytesIO

    buffer = BytesIO()
    _to_jpeg_compatible(image).save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def extract_all_text_pix2struct(image: Image.Image):
    """Run the module-level Pix2Struct model on *image* and return its text.

    Generates at most 512 new tokens, decodes the first returned sequence
    with special tokens skipped, and strips surrounding whitespace.

    NOTE(review): the checkpoint loaded at import is "pix2struct-textcaps-base".
    The name suggests a captioning model, so the output may be a short caption
    rather than a full transcription of every label. Confirm before relying on
    it as OCR.
    """
    encoded = processor(images=image, return_tensors="pt")
    token_ids = pix2struct_model.generate(**encoded, max_new_tokens=512)
    first_sequence = token_ids[0]
    decoded = processor.decode(first_sequence, skip_special_tokens=True)
    return decoded.strip()
# Placeholder for OCR-assisted labelling: intended to match event/gateway names against the extracted text; for now it only fills in missing names with a fallback label.
# The image argument is unused, so its annotation is a string. This avoids a
# runtime dependency on PIL inside this function.
def assign_event_gateway_names_from_ocr(image: "Image.Image", json_data, ocr_text):
    """Fill in missing event/gateway names in the parsed diagram JSON.

    Intended as a hook for matching labels against the OCR text. Currently it
    only replaces an empty, blank or missing ``name`` on each entry of
    ``json_data["events"]`` and ``json_data["gateways"]`` with
    ``"(label unknown)"``. Nothing is changed when *ocr_text* is empty.

    Args:
        image: The source image (currently unused; kept for the interface).
        json_data: Parsed model output. Only a dict is processed. Anything
            else, such as ``None`` when parsing failed or a list, is returned
            unchanged.
        ocr_text: Text extracted from the image.

    Returns:
        *json_data*, mutated in place when it is a dict.
    """
    if not ocr_text or not isinstance(json_data, dict):
        return json_data

    def guess_name_fallback(obj):
        if not isinstance(obj, dict):
            return  # malformed entry: leave it alone rather than crash
        name = obj.get("name")
        # Non-string truthy names (e.g. numbers) are left untouched.
        if not name or (isinstance(name, str) and not name.strip()):
            obj["name"] = "(label unknown)"  # fallback until matching logic exists

    for key in ("events", "gateways"):
        # `or []` also covers keys that are present but null.
        for item in json_data.get(key) or []:
            guess_name_fallback(item)
    return json_data
def run_model(image: Image.Image, api_key: str = None):
    """Extract diagram JSON from *image* with the OpenAI chat model, then apply the OCR assist step.

    Args:
        image: The diagram image to analyse.
        api_key: OpenAI API key. When missing or empty, no request is made and
            a warning message is returned in ``"raw"``.

    Returns:
        A dict ``{"json": parsed, "raw": text}``. *text* is the model's
        stripped reply. *parsed* is the JSON extracted from it, or ``None``
        if none was found.
    """
    # Check the key first, so a missing key is reported without reading the
    # prompt file or encoding the image.
    if not api_key:
        return {"json": None, "raw": "⚠️ API key is missing. Please provide your OpenAI API key."}

    prompt_text = load_prompt()
    encoded_image = encode_image_base64(image)

    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
                ],
            }
        ],
        max_tokens=5000,
    )
    # The SDK types message.content as optional; treat a missing reply as empty text.
    output_text = (response.choices[0].message.content or "").strip()
    parsed_json = try_extract_json(output_text)

    # Vision-language OCR assist step (Pix2Struct). It only operates on a
    # parsed JSON object, so skip the expensive model call otherwise.
    if isinstance(parsed_json, dict):
        full_ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)

    return {"json": parsed_json, "raw": output_text}