prithivMLmods committed on
Commit 3f6a788 · verified
1 Parent(s): 5e4e6ef

Update app.py

Files changed (1):
  1. app.py +107 -43
app.py CHANGED
@@ -1,84 +1,148 @@
 import gradio as gr
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
+from transformers import (
+    Qwen2VLForConditionalGeneration,
+    AutoProcessor,
+    TextIteratorStreamer,
+    AutoModelForImageTextToText,
+)
 from transformers.image_utils import load_image
 from threading import Thread
 import time
 import torch
 import spaces
+from PIL import Image
+import requests
+from io import BytesIO

-# Fine-tuned for OCR-based tasks from Qwen's [ Qwen/Qwen2-VL-2B-Instruct ]
-MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-model = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
+# -------------------------
+# Qwen2-VL Model for OCR-based tasks
+# -------------------------
+QV_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+qwen_processor = AutoProcessor.from_pretrained(QV_MODEL_ID, trust_remote_code=True)
+qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
+    QV_MODEL_ID,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to("cuda").eval()

+# -------------------------
+# Aya-Vision Model for image-text tasks (@aya-vision)
+# -------------------------
+AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
+aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID)
+aya_model = AutoModelForImageTextToText.from_pretrained(
+    AYA_MODEL_ID, device_map="auto", torch_dtype=torch.float16
+)
+
+def aya_vision_chat(image, text_prompt):
+    # If image is provided as a URL, load it via requests.
+    if isinstance(image, str):
+        response = requests.get(image)
+        image = Image.open(BytesIO(response.content))
+
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "image", "image": image},
+            {"type": "text", "text": text_prompt},
+        ],
+    }]
+
+    inputs = aya_processor.apply_chat_template(
+        messages,
+        padding=True,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt"
+    ).to(aya_model.device)
+
+    gen_tokens = aya_model.generate(
+        **inputs, max_new_tokens=300, do_sample=True, temperature=0.3
+    )
+
+    # Decode only the newly generated tokens.
+    response_text = aya_processor.tokenizer.decode(
+        gen_tokens[0][inputs.input_ids.shape[1]:],
+        skip_special_tokens=True
+    )
+    return response_text
+
 @spaces.GPU
 def model_inference(input_dict, history):
-    text = input_dict["text"]
-    files = input_dict["files"]
-
-    # Load images if provided
+    text = input_dict["text"].strip()
+    files = input_dict.get("files", [])
+
+    if text.lower().startswith("@aya-vision"):
+        # Remove the command prefix and trim the prompt.
+        text_prompt = text[len("@aya-vision"):].strip()
+        if not files:
+            yield "Error: Please provide an image for the @aya-vision feature."
+            return
+        else:
+            # For simplicity, use the first provided image.
+            image = load_image(files[0])
+            yield "Processing with Aya-Vision ███████▒▒▒ 69%"
+            response_text = aya_vision_chat(image, text_prompt)
+            yield response_text
+            return
+    # Load images if provided.
     if len(files) > 1:
         images = [load_image(image) for image in files]
     elif len(files) == 1:
         images = [load_image(files[0])]
     else:
         images = []
-
-    # Validate input
+
+    # Validate input: require both text and (optionally) image(s).
     if text == "" and not images:
-        gr.Error("Please input a query and optionally image(s).")
+        yield "Error: Please input a query and optionally image(s)."
         return
     if text == "" and images:
-        gr.Error("Please input a text query along with the image(s).")
+        yield "Error: Please input a text query along with the image(s)."
         return

-    # Prepare messages for the model
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ],
-        }
-    ]
-
-    # Apply chat template and process inputs
-    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(
+    # Prepare messages for the Qwen2-VL model.
+    messages = [{
+        "role": "user",
+        "content": [
+            *[{"type": "image", "image": image} for image in images],
+            {"type": "text", "text": text},
+        ],
+    }]
+
+    prompt = qwen_processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    inputs = qwen_processor(
         text=[prompt],
         images=images if images else None,
         return_tensors="pt",
         padding=True,
     ).to("cuda")
-
-    # Set up streamer for real-time output
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+
+    # Set up a streamer for real-time output.
+    streamer = TextIteratorStreamer(qwen_processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-
-    # Start generation in a separate thread
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+
+    # Start generation in a separate thread.
+    thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
     thread.start()
-
-    # Stream the output
+
     buffer = ""
     yield "Thinking..."
     for new_text in streamer:
         buffer += new_text
-        # Remove <|im_end|> or similar tokens from the output
         buffer = buffer.replace("<|im_end|>", "")
         time.sleep(0.01)
         yield buffer

-# Example inputs
+# -------------------------
+# Example inputs for the combined interface
+# -------------------------
 examples = [
-
-    [{"text": "Extract JSON from the image", "files": ["example_images/document.jpg"]}],
-    [{"text": "summarize the letter", "files": ["examples/1.png"]}],
+    [{"text": "@aya-vision Extract JSON from the image", "files": ["example_images/document.jpg"]}],
+    [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
     [{"text": "Describe the photo", "files": ["examples/3.png"]}],
     [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
     [{"text": "Summarize the full image in detail", "files": ["examples/2.jpg"]}],
@@ -87,12 +151,12 @@ examples = [
     [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
     [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
     [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
-
 ]

+# Build the Gradio ChatInterface.
 demo = gr.ChatInterface(
     fn=model_inference,
-    description="# **Multimodal OCR**",
+    description="# **Multimodal OCR with @aya-vision Feature**",
     examples=examples,
     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
     stop_btn="Stop Generation",
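
A minimal usage sketch of the updated handler (not part of the commit): the dicts below mirror the payload that gr.MultimodalTextbox passes to model_inference, and the file paths are taken from the examples above. A "@aya-vision"-prefixed query is answered in one shot by aya_vision_chat(); any other query streams from the Qwen2-VL OCR model.

# Hypothetical local driver; assumes the example images exist in the Space.
aya_query = {"text": "@aya-vision Describe the photo", "files": ["examples/3.png"]}
ocr_query = {"text": "Extract JSON from the image", "files": ["example_images/document.jpg"]}

for chunk in model_inference(aya_query, history=[]):
    print(chunk)  # progress message, then the single Aya-Vision response
for chunk in model_inference(ocr_query, history=[]):
    print(chunk)  # "Thinking...", then the growing streamed buffer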