import os
import json

from PIL import Image
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from qwen_vl_utils import process_vision_info

# Hugging Face model identifiers used by this module.
QWEN_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
OCR_MODEL_ID = "google/pix2struct-textcaps-base"

# Pixel-count bounds passed to the Qwen processor (written as tile counts * 28 * 28).
QWEN_MIN_PIXELS = 256 * 28 * 28
QWEN_MAX_PIXELS = 1080 * 28 * 28

# Generation budgets (new tokens per call).
QWEN_MAX_NEW_TOKENS = 5000
OCR_MAX_NEW_TOKENS = 512

# Placeholder written into events/gateways that have no usable name.
UNKNOWN_LABEL = "(label unknown)"

# Globals (lazy-loaded at runtime)
qwen_model = None
qwen_processor = None
ocr_model = None
ocr_processor = None


def load_prompt():
    """Return the prompt text sent to the vision-language model.

    The prompt is read from the ``PROMPT_TEXT`` environment variable. When the
    variable is unset, a warning string is returned instead, and that string is
    what gets sent to the model.
    """
    return os.getenv("PROMPT_TEXT", "⚠️ PROMPT_TEXT not found in secrets.")


def try_extract_json(text):
    """Parse JSON from model output, tolerating surrounding prose or code fences.

    First tries to parse *text* as a whole. If that fails, it decodes the JSON
    value that starts at the first ``{`` and ignores any trailing text.

    Returns:
        The decoded value on success; otherwise *text* unchanged.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find("{")
    if start == -1:
        return text
    # raw_decode honours string literals, so braces inside values such as
    # "a } b" do not terminate the object early, unlike a plain brace counter.
    # A truncated object (e.g. generation cut off) still fails -> raw text.
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return text
    return obj


def extract_all_text_pix2struct(image: Image.Image):
    """Run Pix2Struct on *image* and return its decoded, stripped text output.

    The processor and model are loaded on first use and cached in the module
    globals. The model is moved to CUDA when available, otherwise to CPU.

    NOTE(review): the checkpoint is a "textcaps" captioning model, so its output
    may be a caption rather than a full transcription. Confirm it suits OCR use.
    """
    global ocr_model, ocr_processor
    if ocr_model is None or ocr_processor is None:
        ocr_processor = Pix2StructProcessor.from_pretrained(OCR_MODEL_ID)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        ocr_model = Pix2StructForConditionalGeneration.from_pretrained(OCR_MODEL_ID).to(device)

    inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
    predictions = ocr_model.generate(**inputs, max_new_tokens=OCR_MAX_NEW_TOKENS)
    return ocr_processor.decode(predictions[0], skip_special_tokens=True).strip()


def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Ensure every event and gateway in *json_data* has a non-empty ``name``.

    Items under ``"events"`` and ``"gateways"`` whose ``name`` is missing,
    falsy, or whitespace-only get the placeholder ``"(label unknown)"``. The
    dict is modified in place and also returned.

    Note: *ocr_text* is currently only a gate, so nothing happens when it is
    empty. Its content is not yet matched against element labels.

    Input that is empty or not a dict is returned unchanged. This includes the
    raw string that ``try_extract_json`` returns when parsing fails.
    Non-dict entries inside the lists are skipped.
    """
    if not ocr_text or not json_data or not isinstance(json_data, dict):
        return json_data

    for key in ("events", "gateways"):
        # `or []` tolerates an explicit null for the list.
        for obj in json_data.get(key) or []:
            if not isinstance(obj, dict):
                continue
            name = obj.get("name")
            # Only strings can be blank; other truthy values are left untouched.
            if not name or (isinstance(name, str) and not name.strip()):
                obj["name"] = UNKNOWN_LABEL
    return json_data


def _ensure_qwen_loaded():
    """Load the Qwen2.5-VL model and processor into the module globals once."""
    global qwen_model, qwen_processor
    if qwen_model is None or qwen_processor is None:
        qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            QWEN_MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # Flash attention can be enabled here if installed:
            # attn_implementation="flash_attention_2",
        )
        qwen_processor = AutoProcessor.from_pretrained(
            QWEN_MODEL_ID,
            min_pixels=QWEN_MIN_PIXELS,
            max_pixels=QWEN_MAX_PIXELS,
        )


def run_model(image: Image.Image):
    """Extract structured JSON from *image* with Qwen2.5-VL, then post-process it.

    Steps:
      1. Lazily load the model.
      2. Send *image* with the prompt from ``load_prompt()``.
      3. Decode only the newly generated tokens.
      4. Parse them with ``try_extract_json``.
      5. Fill in missing event/gateway names with
         ``assign_event_gateway_names_from_ocr``, gated on Pix2Struct output.

    Returns:
        A dict with two keys:
          * ``"json"``: the parsed value, or the raw text if parsing failed.
          * ``"raw"``: the decoded model output.
    """
    _ensure_qwen_loaded()

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": load_prompt()},
            ],
        }
    ]
    text = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)

    generated_ids = qwen_model.generate(**inputs, max_new_tokens=QWEN_MAX_NEW_TOKENS)
    # The generated sequences include the prompt tokens; keep only the completion.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    parsed_json = try_extract_json(output_text)

    # OCR post-processing (safe even when parsing failed and parsed_json is a str).
    ocr_text = extract_all_text_pix2struct(image)
    parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {"json": parsed_json, "raw": output_text}