# models/qwen.py: commit c05ab88 ("Update models/qwen.py") by ARCQUB.
import os
import json
from PIL import Image
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from qwen_vl_utils import process_vision_info
# Globals (lazy-loaded at runtime)
# Qwen2.5-VL model and processor; populated on the first run_model() call.
qwen_model = None
qwen_processor = None
# Pix2Struct model and processor; populated on the first
# extract_all_text_pix2struct() call.
ocr_model = None
ocr_processor = None
def load_prompt():
    """Return the prompt text configured in the ``PROMPT_TEXT`` environment variable.

    If the variable is not set, return a warning placeholder string instead of
    raising. An empty value is returned as-is.
    """
    fallback = "⚠️ PROMPT_TEXT not found in secrets."
    return os.environ.get("PROMPT_TEXT", fallback)
def try_extract_json(text):
    """Parse *text* as JSON, falling back to the first embedded JSON object.

    Model replies often wrap the JSON payload in prose or a Markdown code
    fence. The whole string is tried first. If that fails, decoding restarts
    at the first ``{`` and reads exactly one JSON object from there, ignoring
    anything after it.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON value on success. This can be any JSON type when the
        whole string parses, and is a ``dict`` when an embedded object is
        extracted. Otherwise the original *text* is returned unchanged.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find("{")
    if start == -1:
        return text

    # raw_decode understands string literals and escapes, so braces inside
    # values such as {"label": "a}b"} do not end the object early (a naive
    # brace counter would). It stops at the end of the object, so trailing
    # text is ignored. Only the first "{" is tried: retrying later positions
    # could return an inner fragment of a truncated reply.
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return text
    return obj
def extract_all_text_pix2struct(image: Image.Image):
    """Run Pix2Struct (``google/pix2struct-textcaps-base``) on *image* and return its text.

    On first use, the processor and model are loaded into the module-level
    ``ocr_processor`` / ``ocr_model`` globals. The model is moved to CUDA when
    available, otherwise to CPU. Generation is capped at 512 new tokens, and
    the decoded output is stripped of surrounding whitespace.

    NOTE(review): judging by its name, this checkpoint is a TextCaps
    image-captioning model. Its output is therefore presumably a caption
    mentioning visible text, not an exhaustive OCR transcript. Confirm this is
    acceptable for callers.
    """
    global ocr_model, ocr_processor
    checkpoint = "google/pix2struct-textcaps-base"

    if ocr_model is None or ocr_processor is None:
        ocr_processor = Pix2StructProcessor.from_pretrained(checkpoint)
        ocr_model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        ocr_model = ocr_model.to(target_device)

    batch = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
    output_ids = ocr_model.generate(**batch, max_new_tokens=512)
    decoded = ocr_processor.decode(output_ids[0], skip_special_tokens=True)
    return decoded.strip()
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Fill in missing names on the ``events`` and ``gateways`` of a parsed diagram.

    Some entries need a placeholder name: any dict entry in the
    ``json_data["events"]`` or ``json_data["gateways"]`` list whose ``"name"``
    is missing, falsy, or whitespace-only. Each such entry is given the
    placeholder ``"(label unknown)"``. The dict is modified in place and also
    returned.

    NOTE(review): despite the function name, the content of *ocr_text* is
    never matched against the entries. It only acts as an on/off switch: when
    it is empty, nothing is changed.

    Args:
        json_data: Parsed model output. Anything other than a non-empty dict is
            returned unchanged. This includes the raw string returned when JSON
            parsing failed, and lists.
        ocr_text: Text recognised in the image; only its truthiness matters.

    Returns:
        *json_data* (the same object).
    """
    if not ocr_text or not json_data or not isinstance(json_data, dict):
        return json_data

    def assign_best_guess(obj):
        # Entries that are not dicts cannot carry a name; leave them alone
        # instead of raising AttributeError.
        if not isinstance(obj, dict):
            return
        name = obj.get("name")
        # Non-string truthy names (e.g. numbers) are kept rather than calling
        # .strip() on them.
        blank = not name or (isinstance(name, str) and not name.strip())
        if blank:
            obj["name"] = "(label unknown)"

    for key in ("events", "gateways"):
        items = json_data.get(key)
        # Skip missing keys, explicit nulls, and any non-list value the model
        # may have produced.
        if not isinstance(items, (list, tuple)):
            continue
        for entry in items:
            assign_best_guess(entry)
    return json_data
def _ensure_qwen_loaded():
    """Load the Qwen2.5-VL model and processor into the module globals once.

    Returns immediately when both ``qwen_model`` and ``qwen_processor`` are
    already populated.
    """
    global qwen_model, qwen_processor
    if qwen_model is not None and qwen_processor is not None:
        return
    qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # You can enable flash attention here if needed:
        # attn_implementation="flash_attention_2"
    )
    # Lower/upper bounds on the pixel count the processor may resize images to.
    min_pixels = 256 * 28 * 28
    max_pixels = 1080 * 28 * 28
    qwen_processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct",
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )


def run_model(image: Image.Image):
    """Run Qwen2.5-VL on *image* with the configured prompt and parse the reply.

    Steps:
      1. Lazily load the Qwen model and processor (first call only).
      2. Build a single-turn user message containing the image and
         ``load_prompt()``.
      3. Generate up to 5000 new tokens and decode only the newly generated
         tokens.
      4. Try to extract JSON from the reply with ``try_extract_json``.
      5. Only if that yields a dict with non-empty ``events`` or ``gateways``,
         run the Pix2Struct pass and fill in missing labels.

    Returns:
        A dict with two keys. ``"json"`` holds the parsed value, or the raw
        text if parsing failed. ``"raw"`` holds the decoded model reply.
    """
    _ensure_qwen_loaded()

    prompt = load_prompt()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)

    generated_ids = qwen_model.generate(**inputs, max_new_tokens=5000)
    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    parsed_json = try_extract_json(output_text)

    # OCR post-processing only touches entries under "events"/"gateways" of a
    # dict. Skip the second (expensive) model when there is nothing it could
    # change. This also avoids handing the plain-string fallback of a failed
    # parse to a helper that calls .get() on it.
    needs_labels = isinstance(parsed_json, dict) and any(
        parsed_json.get(key) for key in ("events", "gateways")
    )
    if needs_labels:
        ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": output_text,
    }