File size: 3,969 Bytes
8093104
 
 
 
 
 
 
 
74e0e2d
8093104
 
74e0e2d
 
 
8093104
74e0e2d
 
8093104
 
 
7440d16
 
 
8093104
74e0e2d
 
8093104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74e0e2d
 
8093104
74e0e2d
 
 
 
 
 
 
 
 
8093104
 
 
 
74e0e2d
 
8093104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74e0e2d
 
8093104
74e0e2d
 
 
 
 
 
 
 
 
8093104
 
 
 
 
 
 
 
 
 
 
 
74e0e2d
8093104
 
 
 
 
 
74e0e2d
8093104
74e0e2d
8093104
 
 
 
 
 
74e0e2d
8093104
 
 
 
 
 
74e0e2d
8093104
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import json
import re
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Hugging Face access token, read once at import time from the HF_TOKEN
# environment variable (None when the variable is unset).
hf_token = os.environ.get("HF_TOKEN")

# Model/processor singletons. They start out empty and are filled in on first
# use, so importing this module does not load any model weights.
aya_model, aya_processor = None, None
ocr_model, ocr_processor = None, None

# Load prompt
def load_prompt():
    """Return the prompt text stored in the ``PROMPT_TEXT`` environment variable.

    If the variable is not set, a warning placeholder string is returned
    instead of raising, so callers always receive a ``str``.
    """
    missing_notice = "⚠️ PROMPT_TEXT not found in secrets."
    return os.environ.get("PROMPT_TEXT", missing_notice)


# Try extracting JSON from text
def try_extract_json(text):
    """Parse *text* as JSON, tolerating non-JSON text around an object.

    The whole string is tried first. If that fails, the first ``{`` is
    located and a single JSON value is decoded starting there; anything
    following the decoded value (closing prose, markdown fences, ...) is
    ignored.

    Args:
        text: Candidate string, e.g. raw model output.

    Returns:
        The decoded JSON value, or ``None`` when *text* is empty/blank,
        contains no ``{``, or no valid JSON object starts at the first ``{``.
    """
    if not text or not text.strip():
        return None

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find("{")
    if start == -1:
        return None

    # Let the real JSON parser find the end of the object. Counting braces by
    # hand miscounts "{" / "}" that appear inside string values and therefore
    # rejects otherwise valid objects.
    try:
        value, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return value


# OCR text from Pix2Struct
def extract_all_text_pix2struct(image: Image.Image):
    """Generate text for *image* with the Pix2Struct TextCaps checkpoint.

    The processor and model are created on the first call (or whenever either
    module-level global is still ``None``), moved to CUDA when available and
    otherwise kept on CPU, and then reused by later calls.

    Args:
        image: The image to run through the model.

    Returns:
        The decoded generation (at most 512 new tokens) with special tokens
        removed and surrounding whitespace stripped.
    """
    global ocr_processor, ocr_model

    already_loaded = ocr_processor is not None and ocr_model is not None
    if not already_loaded:
        checkpoint = "google/pix2struct-textcaps-base"
        ocr_processor = Pix2StructProcessor.from_pretrained(checkpoint)
        ocr_model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        ocr_model = ocr_model.to(target_device)

    # Keep the input tensors on the same device as the model.
    encoded = ocr_processor(images=image, return_tensors="pt")
    encoded = encoded.to(ocr_model.device)

    generated_ids = ocr_model.generate(**encoded, max_new_tokens=512)
    decoded = ocr_processor.decode(generated_ids[0], skip_special_tokens=True)
    return decoded.strip()


# Add fallback names if missing
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Fill in a placeholder name for unnamed events and gateways.

    Every dict found in ``json_data["events"]`` and ``json_data["gateways"]``
    whose ``"name"`` is missing, falsy, or a whitespace-only string gets the
    name ``"(label unknown)"``. The input is modified in place.

    NOTE(review): *ocr_text* only acts as an on/off switch here -- its content
    is never matched against the elements. Confirm whether real OCR-based
    naming was intended.

    Args:
        json_data: Parsed diagram description. Because it originates from
            model output its shape is not trusted: non-dict values, non-list
            ``events``/``gateways`` and non-dict entries are left untouched
            instead of raising.
        ocr_text: Text recognised in the image; when empty nothing is changed.

    Returns:
        The same *json_data* object that was passed in.
    """
    if not ocr_text or not json_data:
        return json_data
    if not isinstance(json_data, dict):
        # e.g. the model answered with a top-level JSON array.
        return json_data

    def assign_best_guess(obj):
        if not isinstance(obj, dict):
            return
        name = obj.get("name")
        # Only strings can be "blank"; other truthy values are kept as-is.
        is_blank_text = isinstance(name, str) and not name.strip()
        if not name or is_blank_text:
            obj["name"] = "(label unknown)"

    for key in ("events", "gateways"):
        entries = json_data.get(key)
        if not isinstance(entries, (list, tuple)):
            # Missing key, JSON null, or an unexpected scalar/object.
            continue
        for entry in entries:
            assign_best_guess(entry)

    return json_data


# Main inference function
def run_model(image: Image.Image):
    """Run the Aya Vision model on *image* and return its parsed output.

    The model and processor are created on the first call and cached in
    module-level globals. The prompt from ``load_prompt()`` is sent together
    with the image as a single user chat message; the generated continuation
    is decoded and passed to ``try_extract_json``. When a non-empty JSON
    object was extracted, the image is additionally run through
    ``extract_all_text_pix2struct`` and unnamed events/gateways receive a
    placeholder name.

    Args:
        image: The diagram image to analyse.

    Returns:
        A dict with two keys: ``"json"`` holds the parsed (and post-processed)
        JSON value or ``None`` when none could be extracted, and ``"raw"``
        holds the decoded model output text.
    """
    global aya_model, aya_processor

    if aya_model is None or aya_processor is None:
        model_id = "CohereForAI/aya-vision-8b"
        aya_processor = AutoProcessor.from_pretrained(model_id)
        aya_model = AutoModelForImageTextToText.from_pretrained(
            model_id, device_map="auto", torch_dtype=torch.float16
        )

    prompt = load_prompt()

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    inputs = aya_processor.apply_chat_template(
        messages,
        padding=True,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(aya_model.device)

    gen_tokens = aya_model.generate(
        **inputs,
        max_new_tokens=5000,
        do_sample=True,
        temperature=0.3,
    )

    # Drop the prompt tokens so only the newly generated text is decoded.
    prompt_length = inputs.input_ids.shape[1]
    output_text = aya_processor.tokenizer.decode(
        gen_tokens[0][prompt_length:],
        skip_special_tokens=True,
    )

    parsed_json = try_extract_json(output_text)

    # OCR enhancement. Loading and running the second model is expensive, so
    # it is skipped when there is nothing to post-process; the post-processing
    # step also only understands JSON objects, not arrays or scalars.
    if isinstance(parsed_json, dict) and parsed_json:
        ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": output_text,
    }