import os
import json
import base64
import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from vllm import LLM
from vllm.sampling_params import SamplingParams
# Hugging Face token from environment (optional)
hf_token = os.getenv("HF_TOKEN")
Image.MAX_IMAGE_PIXELS = None
# Global placeholders (lazy-loaded later)
llm = None
ocr_model = None
ocr_processor = None
sampling_params = SamplingParams(max_tokens=5000)

def load_prompt():
    # with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
    #     return f.read()
    return os.getenv("PROMPT_TEXT", "⚠️ PROMPT_TEXT not found in secrets.")

def try_extract_json(text):
    if not text or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the first balanced {...} block inside the raw model output.
        start = text.find('{')
        if start == -1:
            return None
        brace_count = 0
        json_candidate = ''
        for i in range(start, len(text)):
            if text[i] == '{':
                brace_count += 1
            elif text[i] == '}':
                brace_count -= 1
            json_candidate += text[i]
            if brace_count == 0:
                break
        try:
            return json.loads(json_candidate)
        except json.JSONDecodeError:
            return None

def encode_image_as_base64(pil_image):
    from io import BytesIO
    buffer = BytesIO()
    # JPEG has no alpha channel, so convert (e.g. RGBA/P) images to RGB before saving.
    pil_image.convert("RGB").save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def extract_all_text_pix2struct(image: Image.Image):
    global ocr_processor, ocr_model
    if ocr_processor is None or ocr_model is None:
        ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        ocr_model = ocr_model.to(device)
    inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
    predictions = ocr_model.generate(**inputs, max_new_tokens=512)
    return ocr_processor.decode(predictions[0], skip_special_tokens=True).strip()

def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    if not ocr_text or not json_data:
        return json_data

    def assign_best_guess(obj):
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        assign_best_guess(evt)
    for gw in json_data.get("gateways", []):
        assign_best_guess(gw)
    return json_data

def run_model(image: Image.Image):
    global llm
    if llm is None:
        llm = LLM(
            model="mistralai/Pixtral-12B-2409",
            tokenizer_mode="mistral",
            dtype="bfloat16",
            max_model_len=30000,
        )
    prompt = load_prompt()
    encoded_image = encode_image_as_base64(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
            ]
        }
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)
    raw_output = outputs[0].outputs[0].text
    parsed_json = try_extract_json(raw_output)
    # Apply OCR enrichment
    ocr_text = extract_all_text_pix2struct(image)
    parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
    return {
        "json": parsed_json,
        "raw": raw_output
    }
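
# Minimal local usage sketch (illustrative only): it assumes a GPU with enough
# memory for Pixtral-12B and a hypothetical input file "diagram.png"; the Space
# itself is expected to call run_model() from its own UI code instead.
if __name__ == "__main__":
    sample_image = Image.open("diagram.png")  # hypothetical path, not part of the Space
    result = run_model(sample_image)
    print(json.dumps(result["json"], indent=2) if result["json"] else result["raw"])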