prithivMLmods committed · verified
Commit 64f9a07 · 1 Parent(s): 5e7d4ee

Update app.py

Files changed (1):
  1. app.py  +47 -98
app.py CHANGED
@@ -4,7 +4,6 @@ from transformers import (
     AutoProcessor,
     TextIteratorStreamer,
     AutoModelForImageTextToText,
-    Gemma3ForConditionalGeneration # new Gemma3 model import
 )
 from transformers.image_utils import load_image
 from threading import Thread
@@ -32,10 +31,7 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-### Load Models & Processors ###
-
-# Qwen2VL OCR model (default)
-QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # or alternate version
+QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # or use #prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
 qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
 qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     QV_MODEL_ID,
@@ -43,105 +39,62 @@ qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# Aya-Vision model (trigger with @aya-vision)
 AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
 aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID)
 aya_model = AutoModelForImageTextToText.from_pretrained(
     AYA_MODEL_ID, device_map="auto", torch_dtype=torch.float16
 )
 
-# Gemma3-4b model (trigger with @gemma3-4b)
-GEMMA3_MODEL_ID = "google/gemma-3-4b-it"
-gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
-    GEMMA3_MODEL_ID, device_map="auto"
-).eval()
-gemma3_processor = AutoProcessor.from_pretrained(GEMMA3_MODEL_ID)
-
 @spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"].strip()
     files = input_dict.get("files", [])
 
-    # Branch: Aya-Vision (trigger with @aya-vision)
     if text.lower().startswith("@aya-vision"):
+        # Remove the command prefix and trim the prompt.
         text_prompt = text[len("@aya-vision"):].strip()
         if not files:
             yield "Error: Please provide an image for the @aya-vision feature."
             return
-        image = load_image(files[0])
-        yield progress_bar_html("Processing with Aya-Vision-8b")
-        messages = [{
-            "role": "user",
-            "content": [
-                {"type": "image", "image": image},
-                {"type": "text", "text": text_prompt},
-            ],
-        }]
-        inputs = aya_processor.apply_chat_template(
-            messages,
-            padding=True,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt"
-        ).to(aya_model.device)
-        streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = dict(
-            inputs,
-            streamer=streamer,
-            max_new_tokens=1024,
-            do_sample=True,
-            temperature=0.3
-        )
-        thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-        return
-
-    # Branch: Gemma3-4b (trigger with @gemma3-4b)
-    if text.lower().startswith("@gemma3-4b"):
-        text_prompt = text[len("@gemma3-4b"):].strip()
-        if not files:
-            yield "Error: Please provide an image for the @gemma3-4b feature."
-            return
-        image = load_image(files[0])
-        yield progress_bar_html("Processing with Gemma3-4b")
-        messages = [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": "You are a helpful assistant."}]
-            },
-            {
+        else:
+            # For simplicity, use the first provided image.
+            image = load_image(files[0])
+            yield progress_bar_html("Processing with Aya-Vision-8b")
+            messages = [{
                 "role": "user",
                 "content": [
                     {"type": "image", "image": image},
-                    {"type": "text", "text": text_prompt}
-                ]
-            }
-        ]
-        inputs = gemma3_processor.apply_chat_template(
-            messages, add_generation_prompt=True, tokenize=True,
-            return_dict=True, return_tensors="pt"
-        ).to(gemma3_model.device, dtype=torch.bfloat16)
-        input_len = inputs["input_ids"].shape[-1]
-        streamer = TextIteratorStreamer(gemma3_processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512, do_sample=False)
-        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-        return
+                    {"type": "text", "text": text_prompt},
+                ],
+            }]
+            inputs = aya_processor.apply_chat_template(
+                messages,
+                padding=True,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
+            ).to(aya_model.device)
+            # Set up a streamer for Aya-Vision output
+            streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
+            generation_kwargs = dict(
+                inputs,
+                streamer=streamer,
+                max_new_tokens=1024,
+                do_sample=True,
+                temperature=0.3
+            )
+            thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
+            thread.start()
+            buffer = ""
+            for new_text in streamer:
+                buffer += new_text
+                buffer = buffer.replace("<|im_end|>", "")
+                time.sleep(0.01)
+                yield buffer
+            return
 
-    # Default Branch: Qwen2-VL OCR (for text query with optional images)
+    # Load images if provided.
     if len(files) > 1:
         images = [load_image(image) for image in files]
     elif len(files) == 1:
@@ -149,6 +102,7 @@ def model_inference(input_dict, history):
     else:
         images = []
 
+    # Validate input: require both text and (optionally) image(s).
     if text == "" and not images:
         yield "Error: Please input a query and optionally image(s)."
         return
@@ -156,6 +110,7 @@ def model_inference(input_dict, history):
         yield "Error: Please input a text query along with the image(s)."
         return
 
+    # Prepare messages for the Qwen2-VL model.
     messages = [{
         "role": "user",
         "content": [
@@ -174,9 +129,11 @@ def model_inference(input_dict, history):
         padding=True,
     ).to("cuda")
 
+    # Set up a streamer for real-time output.
     streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
 
+    # Start generation in a separate thread.
     thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
     thread.start()
 
@@ -188,36 +145,28 @@ def model_inference(input_dict, history):
         time.sleep(0.01)
         yield buffer
 
-# Examples for quick testing.
 examples = [
-    [{"text": "@gemma3-4b Summarize the letter", "files": ["examples/1.png"]}],
-    [{"text": "@gemma3-4b Extract JSON from the image", "files": ["example_images/document.jpg"]}],
-    [{"text": "@gemma3-4b Describe the photo", "files": ["examples/3.png"]}],
+    [{"text": "@aya-vision Summarize the letter", "files": ["examples/1.png"]}],
+    [{"text": "@aya-vision Extract JSON from the image", "files": ["example_images/document.jpg"]}],
+    [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
+    [{"text": "@aya-vision Describe the photo", "files": ["examples/3.png"]}],
     [{"text": "@aya-vision Summarize the full image in detail", "files": ["examples/2.jpg"]}],
     [{"text": "@aya-vision Describe this image.", "files": ["example_images/campeones.jpg"]}],
     [{"text": "@aya-vision What is this UI about?", "files": ["example_images/s2w_example.png"]}],
-    [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
    [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
    [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
    [{"text": "@aya-vision Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
 ]
 
-# Gradio ChatInterface with a multimodal textbox.
 demo = gr.ChatInterface(
     fn=model_inference,
-    description=(
-        "# **Multimodal OCR & Vision Features**\n\n"
-        "Use the following commands to select a model:\n"
-        "- `@aya-vision` for Aya-Vision-8b\n"
-        "- `@gemma3-4b` for Gemma3-4b\n\n"
-        "Default processing is done with Qwen2VL OCR."
-    ),
+
     examples=examples,
     textbox=gr.MultimodalTextbox(
         label="Query Input",
         file_types=["image"],
         file_count="multiple",
-        placeholder="Enter your text query and attach images if needed. Use @aya-vision or @gemma3-4b to choose a feature."
+        placeholder="By default, it runs Qwen2VL OCR, Tag @aya-vision for Aya Vision 8B"
     ),
     stop_btn="Stop Generation",
     multimodal=True,
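
Note on the pattern this commit keeps in both the Aya-Vision branch and the default Qwen2-VL path: generate() runs on a background Thread while a TextIteratorStreamer feeds partial text back to the Gradio chat via yield. The sketch below is illustrative only and is not part of the commit; it substitutes a small text-only checkpoint ("gpt2", an assumption chosen so the snippet runs without the vision models or a GPU) to show the same streaming mechanics.

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # stand-in model, illustration only
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Streaming generation works by", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs in a worker thread while the
# main thread consumes partial text from the streamer; this mirrors the structure
# of model_inference() in app.py.
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=40)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text        # accumulate partial output
    print(buffer, end="\r")   # app.py instead does `yield buffer` into the chat
thread.join()
print()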