"""Extract structured process JSON from a diagram image with Pixtral (via vLLM),
then post-process event/gateway names with a Pix2Struct OCR pass."""

import base64
import io
import json
import os

from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from vllm import LLM
from vllm.sampling_params import SamplingParams

# Optional: Replace with your Hugging Face token or use environment variable.
# NOTE(review): hf_token is not passed to any loader in this file — confirm
# whether it should be forwarded (e.g. to the from_pretrained calls).
hf_token = os.getenv("HF_TOKEN")

# WARNING: this disables Pillow's decompression-bomb protection for every image
# opened in this process. Keep only if inputs are trusted (very large diagrams).
Image.MAX_IMAGE_PIXELS = None

# Initialize Pixtral model (loaded once at import time).
model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=5000)
llm = LLM(model=model_name, tokenizer_mode="mistral", dtype="bfloat16", max_model_len=30000)

# Initialize Pix2Struct OCR model (loaded once at import time).
# NOTE(review): the "textcaps" checkpoint is presumably tuned for caption
# generation rather than exhaustive OCR — confirm output quality for diagrams.
_OCR_MODEL_NAME = "google/pix2struct-textcaps-base"
ocr_processor = Pix2StructProcessor.from_pretrained(_OCR_MODEL_NAME)
ocr_model = Pix2StructForConditionalGeneration.from_pretrained(_OCR_MODEL_NAME)

# Path is resolved relative to the current working directory.
_PROMPT_PATH = "prompts/prompt.txt"

# Placeholder written into events/gateways whose name is missing or blank.
_UNKNOWN_LABEL = "(label unknown)"

# Image modes Pillow's JPEG encoder can write directly; anything else must be
# converted first or ``Image.save(format="JPEG")`` raises OSError.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

# Reused decoder for pulling one JSON value out of surrounding text.
_JSON_DECODER = json.JSONDecoder()


def load_prompt():
    """Return the prompt text read fresh from ``prompts/prompt.txt`` (UTF-8)."""
    with open(_PROMPT_PATH, "r", encoding="utf-8") as f:
        return f.read()


def try_extract_json(text):
    """Parse JSON out of raw model output.

    First tries to parse the whole string. If that fails, decodes exactly one
    JSON value starting at the first ``{`` and ignores anything after it
    (closing code fences, trailing commentary).

    Args:
        text: Raw text produced by the model.

    Returns:
        The decoded JSON value, or ``None`` if ``text`` is empty/blank or no
        parseable JSON object is found.
    """
    if not text or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find("{")
    if start == -1:
        return None
    try:
        # raw_decode is a real JSON parser: braces inside string values do not
        # confuse it (naive brace counting would cut the object short).
        parsed, _end = _JSON_DECODER.raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return parsed


def _to_jpeg_compatible(pil_image):
    """Return ``pil_image`` in a mode the JPEG encoder accepts.

    Images with an alpha channel (or palette transparency) are composited onto
    a white background. A bare ``convert("RGB")`` would turn transparent pixels
    black, which makes dark line art on transparent backgrounds unreadable.
    """
    if pil_image.mode in _JPEG_WRITABLE_MODES:
        return pil_image
    has_alpha = pil_image.mode in ("RGBA", "LA", "PA") or (
        pil_image.mode == "P" and "transparency" in pil_image.info
    )
    if has_alpha:
        rgba = pil_image.convert("RGBA")
        background = Image.new("RGB", rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.getchannel("A"))
        return background
    return pil_image.convert("RGB")


def encode_image_as_base64(pil_image):
    """Encode a PIL image as a base64 JPEG string (no ``data:`` prefix).

    Modes JPEG cannot store (RGBA, P, LA, ...) are converted first, so PNG
    inputs with transparency no longer raise.
    """
    buffer = io.BytesIO()
    _to_jpeg_compatible(pil_image).save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def extract_all_text_pix2struct(image: Image.Image):
    """Run the Pix2Struct model on ``image``; return its decoded text, stripped.

    Generation is capped at 512 new tokens.
    """
    inputs = ocr_processor(images=image, return_tensors="pt")
    predictions = ocr_model.generate(**inputs, max_new_tokens=512)
    output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
    return output_text.strip()


def _fill_missing_name(element: dict) -> None:
    """Set ``element["name"]`` to the placeholder if it is missing or blank."""
    name = element.get("name")
    if not name or (isinstance(name, str) and not name.strip()):
        element["name"] = _UNKNOWN_LABEL


def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Fill in missing names on the ``events`` and ``gateways`` entries.

    Any entry whose ``name`` is missing, empty or whitespace-only gets the
    ``"(label unknown)"`` placeholder. ``json_data`` is mutated in place and
    also returned.

    NOTE(review): ``ocr_text`` only gates this step (nothing is changed when it
    is empty). Its content is not yet matched against elements. Matching OCR
    lines to unnamed elements is a TODO.

    Args:
        json_data: Parsed model output; expected to be a dict. Any other value
            (including ``None`` or a JSON list) is returned unchanged.
        ocr_text: Text from the OCR pass.

    Returns:
        ``json_data``, with placeholder names applied when applicable.
    """
    if not ocr_text or not json_data or not isinstance(json_data, dict):
        return json_data

    for section in ("events", "gateways"):
        elements = json_data.get(section)
        # Tolerate null or otherwise malformed sections from the model output.
        if not isinstance(elements, list):
            continue
        for element in elements:
            if isinstance(element, dict):
                _fill_missing_name(element)
    return json_data


def run_model(image: Image.Image):
    """Run Pixtral on ``image`` with the prompt file, then OCR post-processing.

    Returns:
        ``{"json": parsed_or_None, "raw": raw_model_text}``.
    """
    prompt = load_prompt()
    encoded_image = encode_image_as_base64(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                },
            ],
        }
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)
    raw_output = outputs[0].outputs[0].text
    parsed_json = try_extract_json(raw_output)

    # The OCR post-processing only touches a non-empty dict. Skip the costly
    # Pix2Struct pass when there is nothing to annotate.
    if isinstance(parsed_json, dict) and parsed_json:
        ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": raw_output,
    }