File size: 3,643 Bytes
e29cf96
 
 
 
 
 
 
9ac9cd3
e29cf96
 
 
9ac9cd3
 
e29cf96
 
 
1af8f14
 
 
e29cf96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ac9cd3
 
 
 
 
 
 
 
33ad5cc
9ac9cd3
 
e29cf96
 
 
 
 
 
 
 
 
 
 
9ac9cd3
e29cf96
 
 
 
 
 
 
 
 
 
 
 
 
 
5f5b090
e29cf96
5f5b090
e29cf96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ac9cd3
e29cf96
 
 
9ac9cd3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# gpt4o_pix2struct_ocr.py

import os
import json
import base64
from PIL import Image
import openai
import torch

# Name of the OpenAI chat model that receives the prompt and the image.
model = "gpt-4o"

# Pix2Struct model/processor handles; both start unset and are populated
# lazily by extract_all_text_pix2struct() on its first call.
pix2struct_model = processor = None


def load_prompt(prompt_file="prompts/prompt.txt"):
    """Return the prompt text sent to the vision model.

    Lookup order:
      1. The ``PROMPT_TEXT`` environment variable (e.g. a deployment secret),
         returned verbatim when it contains any non-whitespace text.
      2. The UTF-8 text file *prompt_file*, stripped of surrounding
         whitespace, when it exists, is readable and is non-empty.
      3. A warning string, so callers always receive a ``str``.

    Args:
        prompt_file: Path of the fallback prompt file, relative to the
            current working directory unless absolute.

    Returns:
        The prompt text, or the warning string if no prompt was found.
    """
    prompt = os.getenv("PROMPT_TEXT")
    # A set-but-blank variable is treated as missing rather than sent as an empty prompt.
    if prompt and prompt.strip():
        return prompt

    # The file is only a fallback: a missing or unreadable file is not fatal.
    try:
        with open(prompt_file, "r", encoding="utf-8") as f:
            file_prompt = f.read().strip()
    except OSError:
        file_prompt = ""
    if file_prompt:
        return file_prompt

    return "⚠️ PROMPT_TEXT not found in secrets."


def try_extract_json(text):
    """Parse JSON out of a model reply, tolerating surrounding prose.

    The whole of *text* is first tried with ``json.loads``. When the reply is
    pure JSON, any JSON value (object, array, number, ...) is therefore
    returned. Failing that, each ``{`` in *text* is tried in order as the start
    of an embedded JSON object via ``json.JSONDecoder.raw_decode``. Unlike
    naive brace counting, ``raw_decode`` correctly skips braces inside string
    values (e.g. ``{"label": "a}b"}``). It also stops at the end of the object,
    so trailing text such as a closing Markdown code fence is ignored.

    Args:
        text: Raw model output, or ``None``.

    Returns:
        The first successfully decoded JSON value, or ``None`` if nothing
        parseable was found.
    """
    if text is None:
        return None

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next candidate.
            start = text.find("{", start + 1)
    return None


def encode_image_base64(image: Image.Image):
    """Encode *image* as a base64 JPEG string for an OpenAI ``data:`` URL.

    JPEG has no alpha or palette support, and Pillow raises ``OSError`` when
    asked to save e.g. an RGBA PNG as JPEG. Images with transparency are
    therefore composited onto a white background first. A plain
    ``convert("RGB")`` would discard alpha and can expose black behind
    transparent pixels, hiding dark diagram lines. Any other mode that is not
    RGB or L is converted to RGB. The caller's image object is not modified.

    Args:
        image: The PIL image to encode.

    Returns:
        The base64-encoded JPEG bytes as a ``str``.
    """
    from io import BytesIO

    has_alpha = image.mode in ("RGBA", "LA", "PA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if has_alpha:
        rgba = image.convert("RGBA")
        flattened = Image.new("RGB", rgba.size, (255, 255, 255))
        flattened.paste(rgba, mask=rgba.getchannel("A"))
        image = flattened
    elif image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def extract_all_text_pix2struct(image: Image.Image):
    """Run the Pix2Struct model on *image* and return its decoded text.

    On the first call the processor and model for
    ``google/pix2struct-textcaps-base`` are loaded and cached in the
    module-level ``processor`` / ``pix2struct_model`` globals. The model is
    moved to CUDA when available, otherwise CPU. Later calls reuse the cache.

    NOTE(review): the ``textcaps`` checkpoint appears to be a text-aware
    captioning model, so its output may be a caption rather than a full
    transcription of every label — confirm it is adequate as OCR.

    Args:
        image: The PIL image to read.

    Returns:
        The generated text (at most 512 new tokens), stripped of surrounding
        whitespace.
    """
    global pix2struct_model, processor

    if pix2struct_model is None or processor is None:
        from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

        checkpoint = "google/pix2struct-textcaps-base"
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        processor = Pix2StructProcessor.from_pretrained(checkpoint)
        pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(
            checkpoint
        ).to(target_device)

    model_inputs = processor(images=image, return_tensors="pt")
    model_inputs = model_inputs.to(pix2struct_model.device)
    generated_ids = pix2struct_model.generate(**model_inputs, max_new_tokens=512)
    decoded = processor.decode(generated_ids[0], skip_special_tokens=True)
    return decoded.strip()


def assign_event_gateway_names_from_ocr(image: "Image.Image", json_data, ocr_text):
    """Fill in missing names on the extracted events and gateways.

    Despite the name, no OCR text is currently matched to individual
    elements. *ocr_text* only acts as a gate: nothing changes when it is
    empty. *image* is unused. When OCR text is present, every element of
    ``json_data["events"]`` and ``json_data["gateways"]`` gets the placeholder
    ``"(label unknown)"`` if its ``name`` is missing, falsy, or a
    whitespace-only string.

    Args:
        image: The source diagram image (currently unused).
        json_data: Parsed model output. Only a ``dict`` is enriched. ``None``
            (a failed parse) or any other JSON value is returned unchanged
            instead of raising ``AttributeError``.
        ocr_text: Text produced by the OCR pass.

    Returns:
        *json_data*; when it is a dict it is modified in place.
    """
    if not ocr_text or not isinstance(json_data, dict):
        return json_data

    for section in ("events", "gateways"):
        elements = json_data.get(section)
        # The key may be absent, null, or some other non-list value in model output.
        if not isinstance(elements, list):
            continue
        for element in elements:
            if not isinstance(element, dict):
                continue
            name = element.get("name")
            # Non-string truthy names (e.g. numbers) are kept rather than crashing on .strip().
            if not name or (isinstance(name, str) and not name.strip()):
                element["name"] = "(label unknown)"

    return json_data


def run_model(image: Image.Image, api_key: str = None):
    """Extract structured JSON from a diagram image with the OpenAI model.

    The prompt from load_prompt() and the JPEG-encoded image are sent to the
    chat model named by the module-level ``model``. JSON is pulled out of the
    reply. If the result is a dict, it is enriched via the Pix2Struct OCR
    pass.

    Args:
        image: The PIL image to analyse.
        api_key: OpenAI API key; falls back to the ``OPENAI_API_KEY``
            environment variable when not given.

    Returns:
        A dict with keys ``"json"`` (the parsed value, or ``None``) and
        ``"raw"`` (the model's text reply, or a warning when no API key is
        available).
    """
    # Check the key first so no prompt loading or image encoding is wasted.
    api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        return {"json": None, "raw": "⚠️ API key is missing. Please set it as a secret in your Space or upload it as a file."}

    prompt_text = load_prompt()
    encoded_image = encode_image_base64(image)

    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
                ]
            }
        ],
        max_tokens=5000
    )

    # The SDK types message.content as optional (e.g. on a refusal); treat
    # None as an empty reply instead of crashing on .strip().
    output_text = (response.choices[0].message.content or "").strip()
    parsed_json = try_extract_json(output_text)

    # OCR enrichment only applies to dict results; skip the costly Pix2Struct
    # pass (and avoid enriching None / non-dict values) otherwise.
    if isinstance(parsed_json, dict):
        full_ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)

    return {"json": parsed_json, "raw": output_text}