Spaces:
Sleeping
Sleeping
File size: 3,969 Bytes
8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 7440d16 8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 74e0e2d 8093104 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
import json
import re
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Set Hugging Face Token from env (may be None when the secret is unset).
hf_token = os.getenv("HF_TOKEN")

# Lazy-load model objects: kept as module-level singletons and populated
# on first use (see run_model / extract_all_text_pix2struct) so importing
# this module does not trigger any model downloads.
aya_model = None
aya_processor = None
ocr_model = None
ocr_processor = None
# Load prompt
def load_prompt():
#with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
# return f.read()
return os.getenv("PROMPT_TEXT", "⚠️ PROMPT_TEXT not found in secrets.")
# Try extracting JSON from text
def try_extract_json(text):
    """Best-effort extraction of a JSON value from model output.

    First attempts to parse *text* directly.  If that fails, scans forward
    to the first ``{`` and uses ``json.JSONDecoder.raw_decode`` to consume
    exactly one JSON object from that point.  Unlike naive brace counting,
    raw_decode correctly handles braces that appear inside string values
    (e.g. ``{"a": "}"}``) and ignores trailing non-JSON text.

    Args:
        text: Raw text that may contain a JSON object, or None.

    Returns:
        The parsed value, or None when no valid JSON can be found.
    """
    if not text or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    start = text.find('{')
    if start == -1:
        return None
    try:
        # raw_decode parses one complete JSON value and tolerates trailing text.
        obj, _end = json.JSONDecoder().raw_decode(text[start:])
        return obj
    except json.JSONDecodeError:
        return None
# OCR text from Pix2Struct
def extract_all_text_pix2struct(image: Image.Image):
global ocr_processor, ocr_model
if ocr_processor is None or ocr_model is None:
ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
device = "cuda" if torch.cuda.is_available() else "cpu"
ocr_model = ocr_model.to(device)
inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
predictions = ocr_model.generate(**inputs, max_new_tokens=512)
output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
return output_text.strip()
# Add fallback names if missing
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
if not ocr_text or not json_data:
return json_data
def assign_best_guess(obj):
if not obj.get("name") or obj["name"].strip() == "":
obj["name"] = "(label unknown)"
for evt in json_data.get("events", []):
assign_best_guess(evt)
for gw in json_data.get("gateways", []):
assign_best_guess(gw)
return json_data
# Main inference function
def run_model(image: Image.Image):
global aya_model, aya_processor
if aya_model is None or aya_processor is None:
model_id = "CohereForAI/aya-vision-8b"
aya_processor = AutoProcessor.from_pretrained(model_id)
aya_model = AutoModelForImageTextToText.from_pretrained(
model_id, device_map="auto", torch_dtype=torch.float16
)
prompt = load_prompt()
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt}
]
}
]
inputs = aya_processor.apply_chat_template(
messages,
padding=True,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to(aya_model.device)
gen_tokens = aya_model.generate(
**inputs,
max_new_tokens=5000,
do_sample=True,
temperature=0.3,
)
output_text = aya_processor.tokenizer.decode(
gen_tokens[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
parsed_json = try_extract_json(output_text)
# OCR enhancement
ocr_text = extract_all_text_pix2struct(image)
parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
return {
"json": parsed_json,
"raw": output_text
}
|