import json
import os

import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    Pix2StructForConditionalGeneration,
    Pix2StructProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
from qwen_vl_utils import process_vision_info  # helper shipped with the qwen-vl-utils package
# Globals (lazy-loaded at runtime)
qwen_model = None
qwen_processor = None
ocr_model = None
ocr_processor = None
def load_prompt():
    # Originally read from a bundled file; now supplied via an environment
    # variable (a Space secret).
    # with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
    #     return f.read()
    return os.getenv("PROMPT_TEXT", "⚠️ PROMPT_TEXT not found in secrets.")
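# Illustrative only (the actual prompt text is not part of this file): locally
# you would export the secret before launching, e.g.
#
#   export PROMPT_TEXT="Extract all BPMN events, gateways and flows as JSON."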
def try_extract_json(text):
    """Parse the model output as JSON; if the object is embedded in surrounding
    prose, fall back to brace matching. Returns the raw text on failure."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find('{')
        if start == -1:
            return text
        brace_count = 0
        json_candidate = ''
        for i in range(start, len(text)):
            if text[i] == '{':
                brace_count += 1
            elif text[i] == '}':
                brace_count -= 1
            json_candidate += text[i]
            if brace_count == 0:
                break
        try:
            return json.loads(json_candidate)
        except json.JSONDecodeError:
            return text
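# Quick sanity check (illustrative, not executed by the app): the brace-matching
# fallback recovers the object even when the model wraps it in prose or a fence.
#
#   try_extract_json('Sure! ```json\n{"events": [], "gateways": []}\n```')
#   -> {'events': [], 'gateways': []}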
def extract_all_text_pix2struct(image: Image.Image):
    global ocr_model, ocr_processor
    if ocr_model is None or ocr_processor is None:
        ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        ocr_model = ocr_model.to(device)
    inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
    predictions = ocr_model.generate(**inputs, max_new_tokens=512)
    return ocr_processor.decode(predictions[0], skip_special_tokens=True).strip()
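# Note: pix2struct-textcaps-base is a captioning checkpoint, so the returned
# string is closer to a short caption of the visible text than an exhaustive
# OCR transcript. Standalone usage sketch (the file path is hypothetical):
#
#   img = Image.open("diagram.png").convert("RGB")
#   print(extract_all_text_pix2struct(img))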
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    # try_extract_json may return the raw string on parse failure, so guard
    # before treating the payload as a dict. Note that ocr_text is currently
    # only used as a gate; unnamed nodes receive a generic placeholder.
    if not ocr_text or not isinstance(json_data, dict):
        return json_data

    def assign_best_guess(obj):
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        assign_best_guess(evt)
    for gw in json_data.get("gateways", []):
        assign_best_guess(gw)
    return json_data
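# Example (illustrative): any event or gateway with a missing or blank name
# gets the placeholder, regardless of what the OCR text contains.
#
#   assign_event_gateway_names_from_ocr({"events": [{"name": " "}]}, "ocr text")
#   -> {"events": [{"name": "(label unknown)"}]}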
def run_model(image: Image.Image):
    global qwen_model, qwen_processor
    if qwen_model is None or qwen_processor is None:
        qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # You can enable flash attention here if needed:
            # attn_implementation="flash_attention_2"
        )
        # Bound the visual token budget: the processor resizes images so the
        # pixel count stays within [min_pixels, max_pixels].
        min_pixels = 256 * 28 * 28
        max_pixels = 1080 * 28 * 28
        qwen_processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct",
            min_pixels=min_pixels,
            max_pixels=max_pixels
        )

    prompt = load_prompt()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]

    text = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    ).to(qwen_model.device)

    generated_ids = qwen_model.generate(**inputs, max_new_tokens=5000)
    # Strip the prompt tokens so only the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]

    parsed_json = try_extract_json(output_text)

    # OCR post-processing: fill in missing event/gateway labels.
    ocr_text = extract_all_text_pix2struct(image)
    parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": output_text
    }
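# Minimal local smoke test (assumption: in the Space this module is driven by a
# UI layer instead; the image path here is hypothetical).
if __name__ == "__main__":
    demo = Image.open("sample_diagram.png").convert("RGB")
    result = run_model(demo)
    print(result["json"] if isinstance(result["json"], dict) else result["raw"])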