File size: 3,641 Bytes
462e5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c39f47
 
 
462e5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import json
import base64
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from vllm import LLM
from vllm.sampling_params import SamplingParams

# Optional Hugging Face access token, read from the HF_TOKEN environment
# variable (None when unset).
# NOTE(review): nothing in the visible part of this file passes hf_token to a
# loader; confirm whether it is still needed.
hf_token = os.getenv("HF_TOKEN")
# Setting MAX_IMAGE_PIXELS to None switches off Pillow's pixel-count limit
# (the decompression-bomb guard), so very large images are accepted.
Image.MAX_IMAGE_PIXELS = None

# Module-level placeholders, filled in on first use:
#   llm            -> set by run_model()
#   ocr_model      -> set by extract_all_text_pix2struct()
#   ocr_processor  -> set by extract_all_text_pix2struct()
llm = None
ocr_model = None
ocr_processor = None
# Sampling settings passed to llm.chat() in run_model(); max_tokens is 5000.
sampling_params = SamplingParams(max_tokens=5000)


def load_prompt():
    """Return the prompt text taken from the ``PROMPT_TEXT`` environment variable.

    When the variable is not set, a warning placeholder string is returned
    instead of raising, so the caller always receives a ``str``.
    """
    # A file-based source (prompts/prompt.txt) was used previously; the prompt
    # now comes from the environment only.
    missing_notice = "⚠️ PROMPT_TEXT not found in secrets."
    return os.environ.get("PROMPT_TEXT", missing_notice)


def try_extract_json(text):
    """Parse ``text`` as JSON, falling back to the first embedded JSON object.

    The whole string is tried first. If that fails, decoding is attempted
    starting at the first ``{`` in the text, and anything after the end of
    that object is ignored. This lets model output such as
    ``Here is the result: {...} Hope that helps`` be parsed.

    Args:
        text: Raw text that may contain JSON. ``None``, empty and
            whitespace-only values are accepted.

    Returns:
        The decoded value (whatever ``json.loads`` yields when the whole text
        is valid JSON, otherwise the embedded object as a ``dict``), or
        ``None`` when nothing could be decoded.
    """
    if not text or not text.strip():
        return None

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find("{")
    if start == -1:
        return None

    # raw_decode uses the real JSON grammar to find where the object ends, so
    # braces inside string values (e.g. {"label": "a } b"}) do not cut the
    # object short the way a plain brace counter would.
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return obj


def encode_image_as_base64(pil_image):
    """Encode a PIL image as a base64 string of its JPEG representation.

    JPEG cannot store an alpha channel or palette data, so images in such
    modes (typical for PNG uploads) are converted first instead of letting
    ``save`` raise ``OSError``. Transparent images are flattened onto a white
    background. The caller's image object is not modified.

    Args:
        pil_image: The ``PIL.Image.Image`` to encode.

    Returns:
        The JPEG bytes, base64-encoded and decoded to a UTF-8 ``str``.
    """
    from io import BytesIO

    # Modes Pillow's JPEG writer accepts without conversion.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    has_alpha = pil_image.mode in ("RGBA", "LA", "PA") or (
        pil_image.mode == "P" and "transparency" in pil_image.info
    )

    if has_alpha:
        # Composite over white: simply dropping alpha would expose whatever
        # colour sits under the transparent pixels (often black).
        rgba = pil_image.convert("RGBA")
        flattened = Image.new("RGB", rgba.size, (255, 255, 255))
        flattened.paste(rgba, mask=rgba.split()[-1])
        pil_image = flattened
    elif pil_image.mode not in jpeg_modes:
        pil_image = pil_image.convert("RGB")

    buffer = BytesIO()
    pil_image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def extract_all_text_pix2struct(image: Image.Image):
    """Run the Pix2Struct ``textcaps`` checkpoint on ``image`` and return its text.

    The processor and model are loaded on the first call, moved to CUDA when
    available (CPU otherwise), and cached in the module-level
    ``ocr_processor`` / ``ocr_model`` globals for later calls.

    NOTE(review): ``google/pix2struct-textcaps-base`` is published as a
    captioning checkpoint; confirm its output is suitable as OCR text.

    Args:
        image: The image to read.

    Returns:
        The decoded first generated sequence, with special tokens removed and
        surrounding whitespace stripped.
    """
    global ocr_processor, ocr_model

    if ocr_processor is None or ocr_model is None:
        # torch is not imported at module level in this file; import it here
        # so the device probe below does not raise NameError.
        import torch

        model_id = "google/pix2struct-textcaps-base"
        processor = Pix2StructProcessor.from_pretrained(model_id)
        model = Pix2StructForConditionalGeneration.from_pretrained(model_id)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Publish to the module-level cache only after both pieces loaded and
        # the model has been moved, so a failure leaves the cache untouched.
        ocr_model = model.to(device)
        ocr_processor = processor

    # Inputs must live on the same device as the model.
    inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
    predictions = ocr_model.generate(**inputs, max_new_tokens=512)
    return ocr_processor.decode(predictions[0], skip_special_tokens=True).strip()


def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Fill in a placeholder name for unnamed events and gateways.

    Every dict under ``json_data["events"]`` and ``json_data["gateways"]``
    whose ``"name"`` is missing, falsy, or a whitespace-only string gets the
    name ``"(label unknown)"``. ``json_data`` is modified in place.

    ``ocr_text`` only acts as a switch: when it is empty nothing is changed.
    Its content is not used to derive names.

    Because the data comes from parsed model output, unexpected shapes are
    tolerated rather than raising: a non-dict ``json_data``, a non-list
    ``events`` / ``gateways`` value, non-dict entries, and non-string names
    are all left as they are.

    Args:
        json_data: Parsed model output, expected to be a dict.
        ocr_text: Text extracted from the image; enrichment is skipped when
            this is empty.

    Returns:
        The same ``json_data`` object that was passed in.
    """
    if not ocr_text or not json_data or not isinstance(json_data, dict):
        return json_data

    for key in ("events", "gateways"):
        entries = json_data.get(key)
        if not isinstance(entries, (list, tuple)):
            # Missing key, explicit null, or an unexpected container type.
            continue
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            name = entry.get("name")
            # Only strings can be "blank"; other truthy values are kept.
            if not name or (isinstance(name, str) and not name.strip()):
                entry["name"] = "(label unknown)"

    return json_data


def run_model(image: Image.Image):
    """Send ``image`` and the configured prompt to Pixtral and parse the reply.

    The vLLM engine is created on the first call and cached in the
    module-level ``llm`` global. The image is sent as a base64 JPEG data URL
    together with the text returned by ``load_prompt()``.

    When the reply parses to a non-empty JSON object, OCR is run on the image
    and unnamed events/gateways receive placeholder names. OCR is skipped
    otherwise, since there is nothing to enrich.

    Args:
        image: The image to analyse.

    Returns:
        A dict with two keys: ``"json"`` holds the parsed (and possibly
        enriched) model output, or ``None`` when no JSON could be extracted;
        ``"raw"`` holds the unmodified text generated by the model.
    """
    global llm

    if llm is None:
        llm = LLM(
            model="mistralai/Pixtral-12B-2409",
            tokenizer_mode="mistral",
            dtype="bfloat16",
            max_model_len=30000,
        )

    prompt = load_prompt()
    encoded_image = encode_image_as_base64(image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
            ],
        }
    ]

    outputs = llm.chat(messages, sampling_params=sampling_params)
    raw_output = outputs[0].outputs[0].text
    parsed_json = try_extract_json(raw_output)

    # Enrich only a non-empty JSON object. This avoids loading and running the
    # OCR model when there is nothing to enrich, and keeps a non-dict parse
    # result (e.g. a top-level list) from breaking the enrichment step.
    if isinstance(parsed_json, dict) and parsed_json:
        ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": raw_output,
    }