Spaces:
Sleeping
Sleeping
File size: 4,225 Bytes
6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c c05ab88 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef 6c0c37c 9e1e3ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import json
from PIL import Image
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from qwen_vl_utils import process_vision_info
# Module-level model/processor handles. They start out as None and are
# populated on first use (see run_model() and extract_all_text_pix2struct()),
# so importing this module does not download or load any weights.
qwen_model, qwen_processor = None, None
ocr_model, ocr_processor = None, None
def load_prompt():
    """Return the instruction prompt sent to the vision-language model.

    The prompt is read from the ``PROMPT_TEXT`` environment variable. When the
    variable is not set, a warning string is returned in its place instead of
    raising, so callers always receive a ``str``.
    """
    fallback = "⚠️ PROMPT_TEXT not found in secrets."
    return os.environ.get("PROMPT_TEXT", fallback)
def try_extract_json(text):
    """Parse JSON out of model output, tolerating surrounding prose.

    First tries to parse *text* as a whole. If that fails, decodes the JSON
    object that starts at the first ``{`` in *text* and ignores whatever
    follows it (e.g. a closing Markdown fence or trailing commentary).

    Args:
        text: Raw string produced by the model.

    Returns:
        The decoded JSON value on success; otherwise *text* unchanged.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find("{")
    if start == -1:
        return text

    # raw_decode parses one complete JSON value from the start of the string
    # and reports where it ended, so trailing text is ignored. Because it
    # understands JSON string syntax, braces inside string values
    # (e.g. "name": "a}b") no longer cut the object short, as a plain brace
    # counter would.
    try:
        parsed, _end = json.JSONDecoder().raw_decode(text[start:])
    except json.JSONDecodeError:
        return text
    return parsed
def extract_all_text_pix2struct(image: Image.Image):
    """Generate text for *image* with the Pix2Struct TextCaps checkpoint.

    On the first call the processor and model are loaded from
    ``google/pix2struct-textcaps-base`` and cached in the module-level
    ``ocr_processor`` / ``ocr_model`` globals. The model is placed on CUDA
    when available, otherwise on CPU.

    NOTE(review): the TextCaps checkpoint is presumably a captioning model, so
    the output may be a caption mentioning visible text rather than a full
    transcription. Confirm this is what callers expect.

    Args:
        image: PIL image to process.

    Returns:
        The decoded generation (at most 512 new tokens), whitespace-stripped.
    """
    global ocr_model, ocr_processor
    already_loaded = ocr_model is not None and ocr_processor is not None
    if not already_loaded:
        checkpoint = "google/pix2struct-textcaps-base"
        ocr_processor = Pix2StructProcessor.from_pretrained(checkpoint)
        ocr_model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        ocr_model = ocr_model.to(target_device)

    model_inputs = ocr_processor(images=image, return_tensors="pt")
    model_inputs = model_inputs.to(ocr_model.device)
    output_ids = ocr_model.generate(**model_inputs, max_new_tokens=512)
    decoded = ocr_processor.decode(output_ids[0], skip_special_tokens=True)
    return decoded.strip()
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Give every unnamed event and gateway a placeholder name.

    Despite its name, this function does not match OCR text to elements:
    ``ocr_text`` only decides whether any work is done. When both arguments
    are non-empty, each dict entry in ``json_data["events"]`` and
    ``json_data["gateways"]`` whose ``"name"`` is missing, falsy, or a
    whitespace-only string is set to ``"(label unknown)"``. Entries are
    modified in place.

    Args:
        json_data: Parsed model output. Anything that is not a dict (for
            example the raw string returned when JSON parsing failed) is
            returned unchanged.
        ocr_text: Text extracted from the image. If empty, nothing changes.

    Returns:
        ``json_data`` itself (the same object, possibly mutated).
    """
    if not ocr_text or not json_data or not isinstance(json_data, dict):
        return json_data

    def assign_best_guess(obj):
        # Only dict entries can carry a "name"; skip anything else.
        if not isinstance(obj, dict):
            return
        name = obj.get("name")
        # Truthy non-string names (e.g. numbers) are kept as-is rather than
        # crashing on .strip().
        if not name or (isinstance(name, str) and not name.strip()):
            obj["name"] = "(label unknown)"

    # `or []` also covers keys that are present but null in the model output.
    for evt in json_data.get("events") or []:
        assign_best_guess(evt)
    for gw in json_data.get("gateways") or []:
        assign_best_guess(gw)
    return json_data
def _ensure_qwen_loaded():
    """Load the Qwen2.5-VL model and processor into the module globals once."""
    global qwen_model, qwen_processor
    if qwen_model is not None and qwen_processor is not None:
        return
    model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
    qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # To enable FlashAttention-2, also pass
        # attn_implementation="flash_attention_2".
    )
    # Image-resolution budget for the processor, expressed in 28x28-pixel
    # units (see the Qwen2.5-VL model card for min_pixels / max_pixels).
    min_pixels = 256 * 28 * 28
    max_pixels = 1080 * 28 * 28
    qwen_processor = AutoProcessor.from_pretrained(
        model_id,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )


def _generate_qwen_text(image: Image.Image, prompt: str) -> str:
    """Run one image + prompt user turn through Qwen and return the decoded reply."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    chat_text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen_processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)
    generated_ids = qwen_model.generate(**inputs, max_new_tokens=5000)
    # generate() returns prompt + completion tokens; keep only the completion.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]


def run_model(image: Image.Image):
    """Describe *image* as JSON using Qwen2.5-VL, then post-process the result.

    The model and processor are loaded lazily on first call. The prompt comes
    from ``load_prompt()``. The reply is parsed with ``try_extract_json``, and
    when that yields a dict it is passed through the OCR-gated name fill-in
    (``extract_all_text_pix2struct`` + ``assign_event_gateway_names_from_ocr``).

    Args:
        image: PIL image to analyse.

    Returns:
        ``{"json": parsed, "raw": output_text}``, where ``parsed`` is the
        decoded JSON value, or the raw reply string when parsing failed.
    """
    _ensure_qwen_loaded()
    output_text = _generate_qwen_text(image, load_prompt())
    parsed_json = try_extract_json(output_text)

    # OCR post-processing only makes sense for a parsed JSON object. When
    # parsing failed, try_extract_json hands back the raw string. Skipping
    # here avoids a second, wasted model inference and avoids calling dict
    # methods on a str.
    if isinstance(parsed_json, dict):
        ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": output_text,
    }
|