import os
import json
import base64
from io import BytesIO

from PIL import Image
from vllm import LLM
from vllm.sampling_params import SamplingParams
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Optional: Hugging Face token read from the environment; set HF_TOKEN if the
# model downloads require authentication (it is not passed explicitly below)
hf_token = os.getenv("HF_TOKEN")

# Disable Pillow's decompression-bomb guard so very large diagram images load
Image.MAX_IMAGE_PIXELS = None
# Initialize Pixtral model
model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=5000)
llm = LLM(model=model_name, tokenizer_mode="mistral", dtype="bfloat16", max_model_len=30000)
# Initialize Pix2Struct OCR model
ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
# Load the instruction prompt from file
def load_prompt():
    with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
        return f.read()
# Extract structured JSON from free-form model output
def try_extract_json(text):
    if not text or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to scanning for the first balanced {...} block
        start = text.find('{')
        if start == -1:
            return None
        brace_count = 0
        json_candidate = ''
        for i in range(start, len(text)):
            if text[i] == '{':
                brace_count += 1
            elif text[i] == '}':
                brace_count -= 1
            json_candidate += text[i]
            if brace_count == 0:
                break
        try:
            return json.loads(json_candidate)
        except json.JSONDecodeError:
            return None
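# Quick sanity check (illustrative, hypothetical inputs):
#   try_extract_json('Output: {"events": []} -- done')  ->  {"events": []}
#   try_extract_json('no json here')                    ->  None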
# Base64-encode a PIL image as JPEG
def encode_image_as_base64(pil_image):
    buffer = BytesIO()
    # JPEG has no alpha channel, so normalize to RGB before saving
    pil_image.convert("RGB").save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
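# Illustrative note: the returned string is bare base64 (JPEG bytes typically
# encode to a string starting with "/9j/"); run_model below embeds it in a
# data: URL for the vision model.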
# Extract OCR text from the image using Pix2Struct
def extract_all_text_pix2struct(image: Image.Image):
    inputs = ocr_processor(images=image, return_tensors="pt")
    predictions = ocr_model.generate(**inputs, max_new_tokens=512)
    output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
    return output_text.strip()
# Fill in missing event/gateway names using the OCR text
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    if not ocr_text or not json_data:
        return json_data
    # OCR lines are split out here for future matching; the current heuristic
    # only flags unnamed elements rather than assigning a specific OCR line
    lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]

    def assign_best_guess(obj):
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        assign_best_guess(evt)
    for gw in json_data.get("gateways", []):
        assign_best_guess(gw)
    return json_data
# Run the full pipeline: Pixtral extraction followed by OCR post-processing
def run_model(image: Image.Image):
    prompt = load_prompt()
    encoded_image = encode_image_as_base64(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
            ]
        }
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)
    raw_output = outputs[0].outputs[0].text
    parsed_json = try_extract_json(raw_output)

    # Apply OCR post-processing to fill in missing element names
    ocr_text = extract_all_text_pix2struct(image)
    parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": raw_output
    }
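

# Minimal usage sketch (assumption: "diagram.png" is a hypothetical local
# image path; replace it with a real file before running):
if __name__ == "__main__":
    test_image = Image.open("diagram.png")
    result = run_model(test_image)
    print(json.dumps(result["json"], indent=2))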