prithivMLmods committed · verified
Commit a520e3c · Parent: 36ebfe1

Update app.py

Files changed (1): app.py (+98 -47)
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import (
     AutoProcessor,
     TextIteratorStreamer,
     AutoModelForImageTextToText,
+    Gemma3ForConditionalGeneration # new Gemma3 model import
 )
 from transformers.image_utils import load_image
 from threading import Thread
@@ -31,7 +32,10 @@ def progress_bar_html(label: str) -> str:
 </style>
 '''
 
-QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # or use #prithivMLmods/Qwen2-VL-OCR2-2B-Instruct
+### Load Models & Processors ###
+
+# Qwen2VL OCR model (default)
+QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct" # or alternate version
 qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
 qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     QV_MODEL_ID,
@@ -39,62 +43,105 @@ qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# Aya-Vision model (trigger with @aya-vision)
 AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
 aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID)
 aya_model = AutoModelForImageTextToText.from_pretrained(
     AYA_MODEL_ID, device_map="auto", torch_dtype=torch.float16
 )
 
+# Gemma3-4b model (trigger with @gemma3-4b)
+GEMMA3_MODEL_ID = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    GEMMA3_MODEL_ID, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(GEMMA3_MODEL_ID)
+
 @spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"].strip()
     files = input_dict.get("files", [])
 
+    # Branch: Aya-Vision (trigger with @aya-vision)
     if text.lower().startswith("@aya-vision"):
-        # Remove the command prefix and trim the prompt.
         text_prompt = text[len("@aya-vision"):].strip()
         if not files:
             yield "Error: Please provide an image for the @aya-vision feature."
             return
-        else:
-            # For simplicity, use the first provided image.
-            image = load_image(files[0])
-            yield progress_bar_html("Processing with Aya-Vision-8b")
-            messages = [{
-                "role": "user",
-                "content": [
-                    {"type": "image", "image": image},
-                    {"type": "text", "text": text_prompt},
-                ],
-            }]
-            inputs = aya_processor.apply_chat_template(
-                messages,
-                padding=True,
-                add_generation_prompt=True,
-                tokenize=True,
-                return_dict=True,
-                return_tensors="pt"
-            ).to(aya_model.device)
-            # Set up a streamer for Aya-Vision output
-            streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
-            generation_kwargs = dict(
-                inputs,
-                streamer=streamer,
-                max_new_tokens=1024,
-                do_sample=True,
-                temperature=0.3
-            )
-            thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
-            thread.start()
-            buffer = ""
-            for new_text in streamer:
-                buffer += new_text
-                buffer = buffer.replace("<|im_end|>", "")
-                time.sleep(0.01)
-                yield buffer
-            return
+        image = load_image(files[0])
+        yield progress_bar_html("Processing with Aya-Vision-8b")
+        messages = [{
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": text_prompt},
+            ],
+        }]
+        inputs = aya_processor.apply_chat_template(
+            messages,
+            padding=True,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt"
+        ).to(aya_model.device)
+        streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = dict(
+            inputs,
+            streamer=streamer,
+            max_new_tokens=1024,
+            do_sample=True,
+            temperature=0.3
+        )
+        thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
+        return
+
+    # Branch: Gemma3-4b (trigger with @gemma3-4b)
+    if text.lower().startswith("@gemma3-4b"):
+        text_prompt = text[len("@gemma3-4b"):].strip()
+        if not files:
+            yield "Error: Please provide an image for the @gemma3-4b feature."
+            return
+        image = load_image(files[0])
+        yield progress_bar_html("Processing with Gemma3-4b")
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}]
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image},
+                    {"type": "text", "text": text_prompt}
+                ]
+            }
+        ]
+        inputs = gemma3_processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True,
+            return_dict=True, return_tensors="pt"
+        ).to(gemma3_model.device, dtype=torch.bfloat16)
+        input_len = inputs["input_ids"].shape[-1]
+        streamer = TextIteratorStreamer(gemma3_processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512, do_sample=False)
+        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
+        return
 
-    # Load images if provided.
+    # Default Branch: Qwen2-VL OCR (for text query with optional images)
     if len(files) > 1:
         images = [load_image(image) for image in files]
     elif len(files) == 1:
@@ -102,7 +149,6 @@ def model_inference(input_dict, history):
     else:
         images = []
 
-    # Validate input: require both text and (optionally) image(s).
     if text == "" and not images:
         yield "Error: Please input a query and optionally image(s)."
         return
@@ -110,7 +156,6 @@ def model_inference(input_dict, history):
         yield "Error: Please input a text query along with the image(s)."
         return
 
-    # Prepare messages for the Qwen2-VL model.
     messages = [{
         "role": "user",
         "content": [
@@ -129,11 +174,9 @@ def model_inference(input_dict, history):
         padding=True,
     ).to("cuda")
 
-    # Set up a streamer for real-time output.
     streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
 
-    # Start generation in a separate thread.
     thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
     thread.start()
 
@@ -145,28 +188,36 @@ def model_inference(input_dict, history):
         time.sleep(0.01)
         yield buffer
 
+# Examples for quick testing.
 examples = [
-    [{"text": "@aya-vision Summarize the letter", "files": ["examples/1.png"]}],
-    [{"text": "@aya-vision Extract JSON from the image", "files": ["example_images/document.jpg"]}],
-    [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
-    [{"text": "@aya-vision Describe the photo", "files": ["examples/3.png"]}],
+    [{"text": "@gemma3-4b Summarize the letter", "files": ["examples/1.png"]}],
+    [{"text": "@gemma3-4b Extract JSON from the image", "files": ["example_images/document.jpg"]}],
+    [{"text": "@gemma3-4b Describe the photo", "files": ["examples/3.png"]}],
     [{"text": "@aya-vision Summarize the full image in detail", "files": ["examples/2.jpg"]}],
     [{"text": "@aya-vision Describe this image.", "files": ["example_images/campeones.jpg"]}],
     [{"text": "@aya-vision What is this UI about?", "files": ["example_images/s2w_example.png"]}],
+    [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
    [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
     [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
     [{"text": "@aya-vision Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
 ]
 
+# Gradio ChatInterface with a multimodal textbox.
 demo = gr.ChatInterface(
     fn=model_inference,
-    description="# **Multimodal OCR `@aya-vision 'prompt..'`**",
+    description=(
+        "# **Multimodal OCR & Vision Features**\n\n"
+        "Use the following commands to select a model:\n"
+        "- `@aya-vision` for Aya-Vision-8b\n"
+        "- `@gemma3-4b` for Gemma3-4b\n\n"
+        "Default processing is done with Qwen2VL OCR."
+    ),
     examples=examples,
     textbox=gr.MultimodalTextbox(
         label="Query Input",
         file_types=["image"],
         file_count="multiple",
-        placeholder="By default, it runs Qwen2VL OCR, Tag @aya-vision for Aya Vision 8B"
+        placeholder="Enter your text query and attach images if needed. Use @aya-vision or @gemma3-4b to choose a feature."
     ),
     stop_btn="Stop Generation",
     multimodal=True,
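
The new @gemma3-4b branch added by this commit can also be exercised outside the Gradio app. The sketch below mirrors the committed code path (same model ID, chat-template call, and threaded streaming setup); the image path and prompt are copied from the first new entry in the examples list, and it assumes a transformers release that ships Gemma3ForConditionalGeneration plus a CUDA-capable GPU, neither of which is pinned by the commit.

# Minimal sketch of the @gemma3-4b flow from this commit, run standalone.
# Assumptions (not stated in the commit): transformers with Gemma3 support,
# a CUDA GPU, and a local test image at "examples/1.png".
import torch
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image

GEMMA3_MODEL_ID = "google/gemma-3-4b-it"
gemma3_processor = AutoProcessor.from_pretrained(GEMMA3_MODEL_ID)
gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
    GEMMA3_MODEL_ID, device_map="auto"
).eval()

image = load_image("examples/1.png")
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": "Summarize the letter"},
    ]},
]
inputs = gemma3_processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(gemma3_model.device, dtype=torch.bfloat16)

# Same streaming pattern as app.py: generate() runs on a worker thread while
# TextIteratorStreamer yields decoded text chunks on the main thread.
streamer = TextIteratorStreamer(gemma3_processor, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512, do_sample=False)
Thread(target=gemma3_model.generate, kwargs=generation_kwargs).start()
for chunk in streamer:
    print(chunk, end="", flush=True)

Inside the Space itself, the same branch is reached through model_inference({"text": "@gemma3-4b Summarize the letter", "files": ["examples/1.png"]}, history=[]), which is exactly what the first new example entry exercises.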