prithivMLmods committed (verified)
Commit a5d07a8 · 1 Parent(s): f57c7ac

Update app.py

Files changed (1): app.py (+102, -41)
app.py CHANGED
@@ -6,10 +6,55 @@ import time
 import torch
 import spaces
 
-MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" #else ; MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
+DESCRIPTION = """
+# Qwen2.5-VL-3B/7B-Instruct
+"""
+
+css = '''
+h1 {
+  text-align: center;
+  display: block;
+}
+#duplicate-button {
+  margin: auto;
+  color: #fff;
+  background: #1565c0;
+  border-radius: 100vh;
+}
+'''
+
+# Define an animated progress bar HTML snippet
+def progress_bar_html(label: str) -> str:
+    return f'''
+<div style="display: flex; align-items: center;">
+    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+    <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
+        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+    </div>
+</div>
+<style>
+@keyframes loading {{
+    0% {{ transform: translateX(-100%); }}
+    100% {{ transform: translateX(100%); }}
+}}
+</style>
+'''
+
+# Model IDs for 3B and 7B variants
+MODEL_ID_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
+MODEL_ID_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
+
+# Load the processor and models for both versions
+processor_3b = AutoProcessor.from_pretrained(MODEL_ID_3B, trust_remote_code=True)
+model_3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_3B,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16
+).to("cuda").eval()
+
+processor_7b = AutoProcessor.from_pretrained(MODEL_ID_7B, trust_remote_code=True)
+model_7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_7B,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16
 ).to("cuda").eval()
@@ -19,73 +64,89 @@ def model_inference(input_dict, history):
     text = input_dict["text"]
     files = input_dict["files"]
 
+    # Determine which model to use based on the prefix tag
+    if text.lower().startswith("@3b"):
+        yield progress_bar_html("processing with Qwen2.5-VL-3B-Instruct")
+        selected_model = model_3b
+        selected_processor = processor_3b
+        text = text[len("@3b"):].strip()
+    elif text.lower().startswith("@7b"):
+        yield progress_bar_html("processing with Qwen2.5-VL-7B-Instruct")
+        selected_model = model_7b
+        selected_processor = processor_7b
+        text = text[len("@7b"):].strip()
+    else:
+        yield "Error: Please prefix your query with @3b or @7b to select the model."
+        return
+
     # Load images if provided
-    if len(files) > 1:
-        images = [load_image(image) for image in files]
-    elif len(files) == 1:
-        images = [load_image(files[0])]
+    if files:
+        if isinstance(files, list):
+            if len(files) > 1:
+                images = [load_image(image) for image in files]
+            elif len(files) == 1:
+                images = [load_image(files[0])]
+            else:
+                images = []
+        else:
+            images = [load_image(files)]
     else:
         images = []
 
-    # Validate input
-    if text == "" and not images:
-        gr.Error("Please input a query and optionally image(s).")
-        return
-    if text == "" and images:
-        gr.Error("Please input a text query along with the image(s).")
+    # Validate input: text query is required
+    if text == "":
+        yield "Error: Please input a text query along with the image(s) if any."
         return
 
     # Prepare messages for the model
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ],
-        }
-    ]
-
-    # Apply chat template and process inputs
-    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(
+    messages = [{
+        "role": "user",
+        "content": [
+            *[{"type": "image", "image": image} for image in images],
+            {"type": "text", "text": text},
+        ]
+    }]
+
+    # Apply the chat template and process the inputs
+    prompt = selected_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = selected_processor(
         text=[prompt],
         images=images if images else None,
         return_tensors="pt",
         padding=True,
     ).to("cuda")
 
-    # Set up streamer for real-time output
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    # Set up a streamer for real-time text generation
+    streamer = TextIteratorStreamer(selected_processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
 
     # Start generation in a separate thread
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread = Thread(target=selected_model.generate, kwargs=generation_kwargs)
     thread.start()
 
-    # Stream the output
+    # Yield an animated progress message
+    yield progress_bar_html("Thinking...")
+
     buffer = ""
-    yield "Thinking..."
     for new_text in streamer:
         buffer += new_text
         time.sleep(0.01)
         yield buffer
 
-
-# Example inputs
+# Example inputs with model prefixes
 examples = [
-    [{"text": "Describe the document?", "files": ["example_images/document.jpg"]}],
-    [{"text": "What does this say?", "files": ["example_images/math.jpg"]}],
-    [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
-    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
-
+    [{"text": "@3b Describe the document?", "files": ["example_images/document.jpg"]}],
+    [{"text": "@7b What does this say?", "files": ["example_images/math.jpg"]}],
+    [{"text": "@3b What is this UI about?", "files": ["example_images/s2w_example.png"]}],
+    [{"text": "@7b Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
 ]
 
 demo = gr.ChatInterface(
     fn=model_inference,
-    description="# **Qwen2.5-VL-7B-Instruct**",
+    description=DESCRIPTION,
+    css=css,
     examples=examples,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="Use Tags @3b / @7b to trigger the models"),
     stop_btn="Stop Generation",
     multimodal=True,
    cache_examples=False,
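
Note (not part of the commit): the sketch below restates the @3b / @7b prefix routing that model_inference now performs, factored into a standalone helper so the tag handling can be checked without loading either model or a GPU. The helper name route_query and its return shape are illustrative assumptions, not code from app.py.

# Hypothetical helper mirroring the @3b / @7b tag routing added in this commit.
def route_query(text: str):
    """Return (model_id, stripped_query), or None when no recognised tag is present."""
    for tag, model_id in (("@3b", "Qwen/Qwen2.5-VL-3B-Instruct"),
                          ("@7b", "Qwen/Qwen2.5-VL-7B-Instruct")):
        if text.lower().startswith(tag):
            return model_id, text[len(tag):].strip()
    return None

print(route_query("@3b Describe the document?"))  # ('Qwen/Qwen2.5-VL-3B-Instruct', 'Describe the document?')
print(route_query("Describe the document?"))      # None

In app.py itself the unmatched case does not fall through silently: model_inference yields "Error: Please prefix your query with @3b or @7b to select the model." and returns, which is what the new MultimodalTextbox placeholder ("Use Tags @3b / @7b to trigger the models") points users toward.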