whyumesh committed on
Commit
a8a3b01
·
verified ·
1 Parent(s): 478042e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -103
app.py CHANGED
@@ -6,126 +6,147 @@ import cv2
6
  import numpy as np
7
  import gradio as gr
8
 
 
 
 
 
9
  # Load the model and processor
10
  def load_model():
11
- model = Qwen2VLForConditionalGeneration.from_pretrained(
12
- "Qwen/Qwen2-VL-2B-Instruct",
13
- torch_dtype=torch.float16
14
- ).to("cuda" if torch.cuda.is_available() else "cpu")
15
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
16
- return model, processor
17
-
18
- model, processor = load_model()
 
 
 
 
 
 
 
 
19
 
20
  def process_image(image):
21
- messages = [
22
- {
23
- "role": "user",
24
- "content": [
25
- {"type": "image", "image": image},
26
- {"type": "text", "text": "Describe this image."},
27
- ],
28
- }
29
- ]
30
-
31
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
32
- image_inputs, video_inputs = process_vision_info(messages)
33
-
34
- inputs = processor(
35
- text=[text],
36
- images=image_inputs,
37
- videos=video_inputs,
38
- padding=True,
39
- return_tensors="pt",
40
- ).to(model.device)
41
-
42
- with torch.no_grad():
43
- generated_ids = model.generate(**inputs, max_new_tokens=256)
44
- generated_ids_trimmed = [
45
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
46
- ]
47
- output_text = processor.batch_decode(
48
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
49
- )
50
-
51
- return output_text[0]
 
 
 
52
 
53
  def process_video(video_path, max_frames=16, frame_interval=30, max_resolution=224):
54
- cap = cv2.VideoCapture(video_path)
55
- frames = []
56
- frame_count = 0
57
-
58
- while len(frames) < max_frames:
59
- ret, frame = cap.read()
60
- if not ret:
61
- break
62
-
63
- if frame_count % frame_interval == 0:
64
- h, w = frame.shape[:2]
65
- if h > w:
66
- new_h, new_w = max_resolution, int(w * max_resolution / h)
67
- else:
68
- new_h, new_w = int(h * max_resolution / w), max_resolution
69
- frame = cv2.resize(frame, (new_w, new_h))
70
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
71
- frame = Image.fromarray(frame)
72
- frames.append(frame)
73
-
74
- frame_count += 1
75
-
76
- cap.release()
77
-
78
- messages = [
79
- {
80
- "role": "user",
81
- "content": [
82
- {"type": "video", "video": frames},
83
- {"type": "text", "text": "Describe this video."},
84
- ],
85
- }
86
- ]
87
-
88
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
89
- image_inputs, video_inputs = process_vision_info(messages)
90
-
91
- inputs = processor(
92
- text=[text],
93
- images=image_inputs,
94
- videos=video_inputs,
95
- padding=True,
96
- return_tensors="pt",
97
- ).to(model.device)
98
-
99
- with torch.no_grad():
100
- generated_ids = model.generate(**inputs, max_new_tokens=256)
101
- generated_ids_trimmed = [
102
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
103
- ]
104
- output_text = processor.batch_decode(
105
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
106
- )
107
-
108
- return output_text[0]
 
 
 
109
 
110
  def process_content(content):
111
  if content is None:
112
  return "Please upload an image or video file."
113
 
114
- if content.name.lower().endswith(('.png', '.jpg', '.jpeg')):
115
- return process_image(Image.open(content.name))
116
- elif content.name.lower().endswith(('.mp4', '.avi', '.mov')):
117
- return process_video(content.name)
118
- else:
119
- return "Unsupported file type. Please provide an image or video file."
 
 
 
120
 
121
  # Gradio interface
122
  iface = gr.Interface(
123
  fn=process_content,
124
  inputs=gr.File(label="Upload Image or Video"),
125
  outputs="text",
126
- title="Image and Video Description",
127
- description="Upload an image or video to get a description.",
128
  )
129
 
130
  if __name__ == "__main__":
131
- iface.launch(share=True)
 
6
  import numpy as np
7
  import gradio as gr
8
 
9
+ # Check GPU availability
10
+ if not torch.cuda.is_available():
11
+ raise RuntimeError("This application requires a GPU to run. No GPU detected.")
12
+
13
  # Load the model and processor
14
  def load_model():
15
+ try:
16
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
17
+ "Qwen/Qwen2-VL-2B-Instruct",
18
+ torch_dtype=torch.float16 # Use float16 for GPU
19
+ ).to("cuda") # Explicitly use CUDA
20
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
21
+ return model, processor
22
+ except RuntimeError as e:
23
+ print(f"Error loading model: {e}")
24
+ raise
25
+
26
+ try:
27
+ model, processor = load_model()
28
+ except Exception as e:
29
+ print(f"Failed to load model: {e}")
30
+ raise
31
 
32
  def process_image(image):
33
+ try:
34
+ messages = [
35
+ {
36
+ "role": "user",
37
+ "content": [
38
+ {"type": "image", "image": image},
39
+ {"type": "text", "text": "Describe this image."},
40
+ ],
41
+ }
42
+ ]
43
+
44
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
45
+ image_inputs, video_inputs = process_vision_info(messages)
46
+
47
+ inputs = processor(
48
+ text=[text],
49
+ images=image_inputs,
50
+ videos=video_inputs,
51
+ padding=True,
52
+ return_tensors="pt",
53
+ ).to("cuda") # Explicitly use CUDA
54
+
55
+ with torch.no_grad():
56
+ generated_ids = model.generate(**inputs, max_new_tokens=256)
57
+ generated_ids_trimmed = [
58
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
59
+ ]
60
+ output_text = processor.batch_decode(
61
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
62
+ )
63
+
64
+ return output_text[0]
65
+ except Exception as e:
66
+ return f"An error occurred while processing the image: {str(e)}"
67
 
68
  def process_video(video_path, max_frames=16, frame_interval=30, max_resolution=224):
69
+ try:
70
+ cap = cv2.VideoCapture(video_path)
71
+ frames = []
72
+ frame_count = 0
73
+
74
+ while len(frames) < max_frames:
75
+ ret, frame = cap.read()
76
+ if not ret:
77
+ break
78
+
79
+ if frame_count % frame_interval == 0:
80
+ h, w = frame.shape[:2]
81
+ if h > w:
82
+ new_h, new_w = max_resolution, int(w * max_resolution / h)
83
+ else:
84
+ new_h, new_w = int(h * max_resolution / w), max_resolution
85
+ frame = cv2.resize(frame, (new_w, new_h))
86
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
87
+ frame = Image.fromarray(frame)
88
+ frames.append(frame)
89
+
90
+ frame_count += 1
91
+
92
+ cap.release()
93
+
94
+ messages = [
95
+ {
96
+ "role": "user",
97
+ "content": [
98
+ {"type": "video", "video": frames},
99
+ {"type": "text", "text": "Describe this video."},
100
+ ],
101
+ }
102
+ ]
103
+
104
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
105
+ image_inputs, video_inputs = process_vision_info(messages)
106
+
107
+ inputs = processor(
108
+ text=[text],
109
+ images=image_inputs,
110
+ videos=video_inputs,
111
+ padding=True,
112
+ return_tensors="pt",
113
+ ).to("cuda") # Explicitly use CUDA
114
+
115
+ with torch.no_grad():
116
+ generated_ids = model.generate(**inputs, max_new_tokens=256)
117
+ generated_ids_trimmed = [
118
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
119
+ ]
120
+ output_text = processor.batch_decode(
121
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
122
+ )
123
+
124
+ return output_text[0]
125
+ except Exception as e:
126
+ return f"An error occurred while processing the video: {str(e)}"
127
 
128
  def process_content(content):
129
  if content is None:
130
  return "Please upload an image or video file."
131
 
132
+ try:
133
+ if content.name.lower().endswith(('.png', '.jpg', '.jpeg')):
134
+ return process_image(Image.open(content.name))
135
+ elif content.name.lower().endswith(('.mp4', '.avi', '.mov')):
136
+ return process_video(content.name)
137
+ else:
138
+ return "Unsupported file type. Please provide an image or video file."
139
+ except Exception as e:
140
+ return f"An error occurred while processing the content: {str(e)}"
141
 
142
  # Gradio interface
143
  iface = gr.Interface(
144
  fn=process_content,
145
  inputs=gr.File(label="Upload Image or Video"),
146
  outputs="text",
147
+ title="Image and Video Description (GPU Version)",
148
+ description="Upload an image or video to get a description. This application requires GPU computation.",
149
  )
150
 
151
  if __name__ == "__main__":
152
+ iface.launch()