Spaces:
Running
Running
import os | |
import json | |
import base64 | |
from PIL import Image | |
from vllm import LLM | |
from vllm.sampling_params import SamplingParams | |
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor | |
# Hugging Face access token, read from the environment (None when unset).
# NOTE(review): hf_token is not passed to either model loader below; confirm
# whether it is meant to be forwarded or can be dropped.
hf_token = os.getenv("HF_TOKEN")

# Lift Pillow's decompression-bomb pixel limit so very large images can open.
# NOTE(review): this also removes Pillow's protection for untrusted uploads;
# consider a finite limit if images come from end users.
Image.MAX_IMAGE_PIXELS = None

# Pixtral vision-language model, served through vLLM.
model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=5000)
llm = LLM(
    model=model_name,
    tokenizer_mode="mistral",
    dtype="bfloat16",
    max_model_len=30000,
)

# Pix2Struct model used for the OCR post-processing pass.
_PIX2STRUCT_CHECKPOINT = "google/pix2struct-textcaps-base"
ocr_processor = Pix2StructProcessor.from_pretrained(_PIX2STRUCT_CHECKPOINT)
ocr_model = Pix2StructForConditionalGeneration.from_pretrained(_PIX2STRUCT_CHECKPOINT)
# Load the extraction prompt from disk
def load_prompt(path="prompts/prompt.txt"):
    """Return the full text of a UTF-8 prompt file.

    Args:
        path: File to read (str or path-like). Defaults to
            ``prompts/prompt.txt``. A relative path is resolved against the
            current working directory.

    Returns:
        The file contents as a string. The file is re-read on every call.

    Raises:
        OSError: e.g. ``FileNotFoundError`` if ``path`` does not exist.
    """
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
# Extract structured JSON from model text output
def try_extract_json(text):
    """Parse JSON out of free-form text such as an LLM reply.

    The whole text is tried first. If it is not valid JSON, the first ``{``
    is located and a single JSON object is decoded starting there. Any text
    after that object (closing Markdown fences, commentary) is ignored.

    Args:
        text: The text to parse. ``None``, empty and whitespace-only strings
            are accepted.

    Returns:
        The decoded value when the whole text is JSON (this may be any JSON
        type, e.g. a list). Otherwise, the first embedded JSON object as a
        dict, or ``None`` if nothing could be decoded.
    """
    if not text or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find("{")
    if start == -1:
        return None
    # Use the real JSON scanner instead of counting braces, so that a brace
    # inside a string value (e.g. {"name": "a}b"}) does not end the object early.
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return obj
# Encode a PIL image as a base64 JPEG string
def encode_image_as_base64(pil_image):
    """Serialize ``pil_image`` as JPEG and return it base64-encoded.

    Pillow raises ``OSError`` when saving modes such as ``RGBA`` or ``P`` as
    JPEG. For that reason, any image that is not already ``RGB`` or ``L`` is
    converted to ``RGB`` first.

    Images with an alpha channel, or with palette transparency, are
    composited onto a white background. This avoids exposing whatever colour
    values the fully transparent pixels happen to store. The caller's image
    is not modified.

    Args:
        pil_image: A ``PIL.Image.Image``.

    Returns:
        The JPEG bytes encoded as base64, as a ``str`` with no data-URL prefix.
    """
    from io import BytesIO

    image = pil_image
    if image.mode not in ("RGB", "L"):
        has_alpha = "A" in image.getbands() or (
            image.mode == "P" and "transparency" in image.info
        )
        if has_alpha:
            rgba = image.convert("RGBA")
            image = Image.new("RGB", rgba.size, (255, 255, 255))
            image.paste(rgba, mask=rgba.getchannel("A"))
        else:
            image = image.convert("RGB")

    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
# Text extraction with the Pix2Struct model
def extract_all_text_pix2struct(image: Image.Image):
    """Run the Pix2Struct model on ``image`` and return its decoded text.

    Generation is capped at 512 new tokens. Only the first generated sequence
    is decoded, with special tokens skipped and surrounding whitespace
    stripped.

    NOTE(review): "textcaps" checkpoints appear to be captioning models.
    Confirm their output is a usable transcription for name matching.
    """
    generated = ocr_model.generate(
        **ocr_processor(images=image, return_tensors="pt"),
        max_new_tokens=512,
    )
    text = ocr_processor.decode(generated[0], skip_special_tokens=True)
    return text.strip()
# Fill in missing event/gateway names (gated on OCR text being present)
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Give every event and gateway in ``json_data`` a non-empty ``"name"``.

    Entries of ``json_data["events"]`` and ``json_data["gateways"]`` get the
    placeholder ``"(label unknown)"`` when their ``"name"`` is missing,
    falsy, or a whitespace-only string. The dict is modified in place and
    also returned.

    Note: despite the function name, ``ocr_text`` is currently used only as a
    gate. When it is empty the data is returned untouched, and its contents
    are not matched against the entries.

    Args:
        json_data: Parsed model output. Anything that is not a non-empty dict
            (e.g. ``None`` or a list) is returned unchanged.
        ocr_text: Text extracted by the OCR model.

    Returns:
        ``json_data`` itself.
    """
    if not ocr_text or not isinstance(json_data, dict) or not json_data:
        return json_data

    def _needs_placeholder(name):
        if not name:
            return True
        # Non-string truthy names (e.g. numbers) are kept as-is.
        return isinstance(name, str) and not name.strip()

    for key in ("events", "gateways"):
        # The data comes from model output, so the list may be null and its
        # items may not be objects. Tolerate both instead of raising.
        for item in json_data.get(key) or []:
            if isinstance(item, dict) and _needs_placeholder(item.get("name")):
                item["name"] = "(label unknown)"
    return json_data
# Run the full extraction pipeline on one image
def run_model(image: Image.Image):
    """Extract JSON from ``image`` with Pixtral, then apply OCR post-processing.

    Steps:
      1. Send the prompt from ``load_prompt()`` plus the image, as a base64
         JPEG data URL, to the vLLM ``llm`` in a single user message.
      2. Parse JSON out of the model's text reply with ``try_extract_json``.
      3. If something was parsed, run Pix2Struct on the same image and pass
         its text to ``assign_event_gateway_names_from_ocr``.

    Args:
        image: The input PIL image.

    Returns:
        A dict with two keys:
          - ``"json"``: the parsed, post-processed value, or ``None`` if no
            JSON could be extracted.
          - ``"raw"``: the model's unparsed text output.
    """
    prompt = load_prompt()
    encoded_image = encode_image_as_base64(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
            ],
        }
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)
    raw_output = outputs[0].outputs[0].text
    parsed_json = try_extract_json(raw_output)
    # The post-processing step returns its input unchanged when that input is
    # falsy. Skip the Pix2Struct generation pass in that case rather than run
    # an inference whose result is discarded.
    if parsed_json:
        ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
    return {
        "json": parsed_json,
        "raw": raw_output,
    }