ARCQUB committed · verified
Commit 74e0e2d · 1 Parent(s): 9ac9cd3

Update models/aya_vision.py

Files changed (1):
  1. models/aya_vision.py  +37 -26
models/aya_vision.py CHANGED
@@ -6,37 +6,32 @@ import torch
 from transformers import AutoProcessor, AutoModelForImageTextToText
 from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
 
-# Set Hugging Face Token
+# Set Hugging Face Token from env
 hf_token = os.getenv("HF_TOKEN")
 
-# Initialize Aya Vision Model
-model_id = "CohereForAI/aya-vision-8b"
-processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForImageTextToText.from_pretrained(
-    model_id, device_map="auto", torch_dtype=torch.float16
-)
+# Lazy-load model objects
+aya_model = None
+aya_processor = None
 
-# Initialize Pix2Struct OCR Model
-ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
-ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+ocr_model = None
+ocr_processor = None
 
 # Load prompt
 def load_prompt():
     with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
         return f.read()
 
-# Try extracting JSON from model output
+
+# Try extracting JSON from text
 def try_extract_json(text):
     if not text or not text.strip():
         return None
     try:
         return json.loads(text)
     except json.JSONDecodeError:
-        # Try extracting JSON substring by brace balancing
         start = text.find('{')
         if start == -1:
             return None
-
         brace_count = 0
         json_candidate = ''
         for i in range(start, len(text)):
@@ -48,26 +43,33 @@ def try_extract_json(text):
             json_candidate += char
             if brace_count == 0:
                 break
-
         try:
             return json.loads(json_candidate)
         except json.JSONDecodeError:
             return None
 
-# Extract OCR text using Pix2Struct
+
+# OCR text from Pix2Struct
 def extract_all_text_pix2struct(image: Image.Image):
-    inputs = ocr_processor(images=image, return_tensors="pt")
+    global ocr_processor, ocr_model
+
+    if ocr_processor is None or ocr_model is None:
+        ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
+        ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        ocr_model = ocr_model.to(device)
+
+    inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
     predictions = ocr_model.generate(**inputs, max_new_tokens=512)
     output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
     return output_text.strip()
 
-# Assign event/gateway names from OCR text
+
+# Add fallback names if missing
 def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
     if not ocr_text or not json_data:
         return json_data
 
-    lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]
-
     def assign_best_guess(obj):
         if not obj.get("name") or obj["name"].strip() == "":
             obj["name"] = "(label unknown)"
@@ -80,8 +82,18 @@ def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
 
     return json_data
 
-# Run Aya model on image
+
+# Main inference function
 def run_model(image: Image.Image):
+    global aya_model, aya_processor
+
+    if aya_model is None or aya_processor is None:
+        model_id = "CohereForAI/aya-vision-8b"
+        aya_processor = AutoProcessor.from_pretrained(model_id)
+        aya_model = AutoModelForImageTextToText.from_pretrained(
+            model_id, device_map="auto", torch_dtype=torch.float16
+        )
+
     prompt = load_prompt()
 
     messages = [
@@ -94,34 +106,33 @@ def run_model(image: Image.Image):
         }
     ]
 
-    inputs = processor.apply_chat_template(
+    inputs = aya_processor.apply_chat_template(
         messages,
         padding=True,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
         return_tensors="pt"
-    ).to(model.device)
+    ).to(aya_model.device)
 
-    gen_tokens = model.generate(
+    gen_tokens = aya_model.generate(
         **inputs,
         max_new_tokens=5000,
         do_sample=True,
         temperature=0.3,
     )
 
-    output_text = processor.tokenizer.decode(
+    output_text = aya_processor.tokenizer.decode(
         gen_tokens[0][inputs.input_ids.shape[1]:],
         skip_special_tokens=True
     )
 
     parsed_json = try_extract_json(output_text)
 
-    # Apply OCR post-processing
+    # OCR enhancement
     ocr_text = extract_all_text_pix2struct(image)
     parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
 
-    # Return both parsed and raw
     return {
        "json": parsed_json,
        "raw": output_text