whyumesh committed
Commit 114c949 · verified · Parent: b7f3d17

Create app.py

Files changed (1):
  app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ from PIL import Image
+ import cv2
+ import numpy as np
+ import gradio as gr
+
+ # Load the model and processor
+ def load_model():
+     # Use float16 on GPU; fall back to float32 on CPU, where half-precision ops are often unsupported
+     model = Qwen2VLForConditionalGeneration.from_pretrained(
+         "Qwen/Qwen2-VL-2B-Instruct",
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+     ).to("cuda" if torch.cuda.is_available() else "cpu")
+     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+     return model, processor
+
+ model, processor = load_model()
+
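+ # Build a single-turn chat request around one image and decode the model's reply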
+ def process_image(image):
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": "Describe this image."},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     image_inputs, video_inputs = process_vision_info(messages)
+
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     ).to(model.device)
+
+     with torch.no_grad():
+         generated_ids = model.generate(**inputs, max_new_tokens=256)
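+     # Drop the prompt tokens so only the newly generated text is decoded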
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )
+
+     return output_text[0]
+
+ def process_video(video_path, max_frames=16, frame_interval=30, max_resolution=224):
+     cap = cv2.VideoCapture(video_path)
+     frames = []
+     frame_count = 0
+
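+     # Keep one frame every frame_interval frames, up to max_frames total,
+     # resizing each so its longer side is at most max_resolution pixels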
+     while len(frames) < max_frames:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         if frame_count % frame_interval == 0:
+             h, w = frame.shape[:2]
+             if h > w:
+                 new_h, new_w = max_resolution, int(w * max_resolution / h)
+             else:
+                 new_h, new_w = int(h * max_resolution / w), max_resolution
+             frame = cv2.resize(frame, (new_w, new_h))
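+             # OpenCV decodes frames as BGR; convert to RGB before building PIL images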
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frame = Image.fromarray(frame)
+             frames.append(frame)
+
+         frame_count += 1
+
+     cap.release()
+
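+     # The sampled PIL frames are passed directly as the "video" entry of the chat message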
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "video", "video": frames},
+                 {"type": "text", "text": "Describe this video."},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     image_inputs, video_inputs = process_vision_info(messages)
+
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     ).to(model.device)
+
+     with torch.no_grad():
+         generated_ids = model.generate(**inputs, max_new_tokens=256)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )
+
+     return output_text[0]
+
+ def process_content(content):
+     if content is None:
+         return "Please upload an image or video file."
+
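+     # Route by file extension: open images with PIL, sample videos with OpenCV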
+     if content.name.lower().endswith(('.png', '.jpg', '.jpeg')):
+         return process_image(Image.open(content.name))
+     elif content.name.lower().endswith(('.mp4', '.avi', '.mov')):
+         return process_video(content.name)
+     else:
+         return "Unsupported file type. Please provide an image or video file."
+
+ # Gradio interface
+ iface = gr.Interface(
+     fn=process_content,
+     inputs=gr.File(label="Upload Image or Video"),
+     outputs="text",
+     title="Image and Video Description",
+     description="Upload an image or video to get a description.",
+ )
+
+ if __name__ == "__main__":
+     iface.launch()