prithivMLmods committed
Commit 5633a75 (verified)
Parent(s): a4cab0f

Update app.py

Files changed (1): app.py (+24 -11)
app.py CHANGED
@@ -5,6 +5,7 @@ from threading import Thread
 import time
 import torch
 import spaces
+from qwen_vl_utils import process_vision_info
 
 # Fine-tuned for OCR-based tasks from Qwen's [ Qwen/Qwen2-VL-2B-Instruct ]
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
@@ -22,18 +23,25 @@ def model_inference(input_dict, history):
 
     # Load images if provided
     if len(files) > 1:
-        images = [load_image(image) for image in files]
+        images = [load_image(image) for image in files if image.endswith(('png', 'jpg', 'jpeg'))]
+        videos = [video for video in files if video.endswith(('mp4', 'avi', 'mov'))]
     elif len(files) == 1:
-        images = [load_image(files[0])]
+        if files[0].endswith(('png', 'jpg', 'jpeg')):
+            images = [load_image(files[0])]
+            videos = []
+        else:
+            images = []
+            videos = [files[0]]
     else:
         images = []
+        videos = []
 
     # Validate input
-    if text == "" and not images:
-        gr.Error("Please input a query and optionally image(s).")
+    if text == "" and not images and not videos:
+        gr.Error("Please input a query and optionally image(s) or video(s).")
         return
-    if text == "" and images:
-        gr.Error("Please input a text query along with the image(s).")
+    if text == "" and (images or videos):
+        gr.Error("Please input a text query along with the image(s) or video(s).")
         return
 
     # Prepare messages for the model
@@ -42,18 +50,24 @@ def model_inference(input_dict, history):
             "role": "user",
             "content": [
                 *[{"type": "image", "image": image} for image in images],
+                *[{"type": "video", "video": video} for video in videos],
                 {"type": "text", "text": text},
             ],
         }
     ]
 
+    # Process vision info (images and videos)
+    image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
+
     # Apply chat template and process inputs
     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
         text=[prompt],
-        images=images if images else None,
-        return_tensors="pt",
+        images=image_inputs,
+        videos=video_inputs,
         padding=True,
+        return_tensors="pt",
+        **video_kwargs,
     ).to("cuda")
 
     # Set up streamer for real-time output
@@ -76,7 +90,6 @@ def model_inference(input_dict, history):
 
 # Example inputs
 examples = [
-
     [{"text": "Extract JSON from the image", "files": ["example_images/document.jpg"]}],
     [{"text": "summarize the letter", "files": ["examples/1.png"]}],
     [{"text": "Describe the photo", "files": ["examples/3.png"]}],
@@ -87,14 +100,14 @@ examples = [
     [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
     [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
     [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
-
+    [{"text": "Describe the video.", "files": ["example_videos/sample.mp4"]}],
 ]
 
 demo = gr.ChatInterface(
     fn=model_inference,
     description="# **Multimodal OCR**",
     examples=examples,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
     stop_btn="Stop Generation",
     multimodal=True,
     cache_examples=False,
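The core of the change is routing every visual input through qwen_vl_utils.process_vision_info instead of handing raw PIL images to the processor: the helper resolves the image/video entries in the chat messages into processor-ready inputs, and with return_video_kwargs=True (the flag this commit relies on, available in recent qwen-vl-utils releases) it also returns per-video keyword arguments that are forwarded via **video_kwargs. A minimal standalone sketch of that pattern, not part of the commit; it assumes qwen-vl-utils is installed and that the example image path from the Space exists locally:

from qwen_vl_utils import process_vision_info

# Same message schema the Space builds in model_inference: image/video entries
# can be file paths, URLs, or already-loaded PIL images.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "example_images/document.jpg"},
            {"type": "text", "text": "Extract JSON from the image"},
        ],
    }
]

# Two-tuple form: images and videos resolved from the message content.
image_inputs, video_inputs = process_vision_info(messages)

# Three-tuple form used by the commit: additionally returns per-video kwargs
# (e.g. sampling fps) that the Space unpacks into the processor call.
image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)

# With image-only input, video_inputs is None, so the processor's videos
# argument is effectively skipped.
print(type(image_inputs[0]).__name__, video_inputs, video_kwargs)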
 
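One pre-existing quirk the commit carries over: the validation block constructs gr.Error(...) and then returns, but in Gradio gr.Error is an exception class, so building it without raise never surfaces the message in the UI. A sketch of how the check could show the error to the user; the validate helper is illustrative only and not part of app.py:

import gradio as gr

def validate(text: str, images: list, videos: list) -> None:
    # gr.Error must be raised for Gradio to display the error modal;
    # a bare call constructs the exception and silently discards it.
    if text == "" and not images and not videos:
        raise gr.Error("Please input a query and optionally image(s) or video(s).")
    if text == "" and (images or videos):
        raise gr.Error("Please input a text query along with the image(s) or video(s).")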