Thong Nguyen committed
Commit ca276a7
1 Parent(s): c1b1966

add mama code

Files changed (3)
  1. app.py +198 -5
  2. requirements.txt +8 -2
  3. video_keyframe_detector +1 -0
app.py CHANGED
@@ -1,16 +1,209 @@
  import gradio as gr
- import whisper
- import tempfile
+ import argparse
+ import shutil
  import os
+ from video_keyframe_detector.cli import keyframeDetection
+ import numpy as np
+ import cv2
+ from llava.constants import (
+     IMAGE_TOKEN_INDEX,
+     DEFAULT_IMAGE_TOKEN,
+     DEFAULT_IM_START_TOKEN,
+     DEFAULT_IM_END_TOKEN,
+     IMAGE_PLACEHOLDER,
+ )
+ from PIL import Image
+ import requests  # needed by load_image() for remote image URLs
+ from io import BytesIO  # needed by load_image() for remote image URLs
+ from llava.conversation import conv_templates, SeparatorStyle
+ from llava.model.builder import load_pretrained_model
+ from llava.utils import disable_torch_init
+ from llava.mm_utils import (
+     process_images,
+     tokenizer_image_token,
+     get_model_name_from_path,
+     KeywordsStoppingCriteria,
+ )
+ import torch
 
- model = whisper.load_model("base")
+ 
+ 
+ def extract_keyframes(video_path, num_keyframes=12):
+     video_id = video_path.split('/')[-1].strip().split('.')[0]
+ 
+     os.makedirs("temp", exist_ok=True)
+ 
+     keyframeDetection(video_path, "temp", 0.6)
+     video_frame_list = sorted(os.listdir(os.path.join("temp", "keyFrames")), key=lambda x: int(x.split('.')[0][8:]))
+     os.makedirs(os.path.join("video_frames", video_id), exist_ok=True)
+     # Evenly sample num_keyframes of the detected keyframes.
+     selected_frame_idx_set = set(np.linspace(1, len(video_frame_list) - 1, num_keyframes).astype(int))
+     cnt = 0
+     for i in range(len(video_frame_list)):
+         if i in selected_frame_idx_set:
+             source_file = os.path.join("temp", "keyFrames", video_frame_list[i])
+             target_file = os.path.join("video_frames", video_id, f"frame_{cnt}.jpg")
+             shutil.copyfile(source_file, target_file)
+             cnt += 1
+ 
+     shutil.rmtree("temp", ignore_errors=True)
+ 
+ 
+ def concatenate_frames(video_path):
+     os.makedirs("concatenated_frames", exist_ok=True)
+     video_id = video_path.split('/')[-1].strip().split('.')[0]
+     image_frame_dir = os.path.join("video_frames", video_id)
+     image_frame_list = sorted(os.listdir(image_frame_dir), key=lambda x: int(x.split('.')[0].split('_')[1]))
+     img_list = []
+     for image_frame in image_frame_list:
+         img_frame = cv2.imread(os.path.join(image_frame_dir, image_frame))
+         img_list.append(img_frame)
+ 
+     # Tile the selected frames into a 3x4 grid image.
+     img_row1 = cv2.hconcat(img_list[:4])
+     img_row2 = cv2.hconcat(img_list[4:8])
+     img_row3 = cv2.hconcat(img_list[8:12])
+ 
+     img_v = cv2.vconcat([img_row1, img_row2, img_row3])
+     cv2.imwrite(os.path.join("concatenated_frames", f"{video_id}.jpg"), img_v)
+ 
+ 
+ def image_parser(args):
+     out = args.image_file.split(args.sep)
+     return out
+ 
+ 
+ def load_image(image_file):
+     if image_file.startswith("http") or image_file.startswith("https"):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert("RGB")
+     else:
+         image = Image.open(image_file).convert("RGB")
+     return image
+ 
+ 
+ def load_images(image_files):
+     out = []
+     for image_file in image_files:
+         image = load_image(image_file)
+         out.append(image)
+     return out
+ 
+ 
+ def eval_model(args, model_name, tokenizer, model, image_processor, context_len):
+     # Model
+     DEFAULT_IMAGE_TOKEN = "<image>"
+     disable_torch_init()
+ 
+     qs = args.query
+     image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+ 
+     if model.config.mm_use_im_start_end:
+         qs = image_token_se + "\n" + qs
+     else:
+         qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+ 
+     if "llama-2" in model_name.lower():
+         conv_mode = "llava_llama_2"
+     elif "v1" in model_name.lower():
+         conv_mode = "llava_v1"
+     elif "mpt" in model_name.lower():
+         conv_mode = "mpt"
+     else:
+         conv_mode = "llava_v0"
+ 
+     if args.conv_mode is not None and conv_mode != args.conv_mode:
+         print(
+             "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
+                 conv_mode, args.conv_mode, args.conv_mode
+             )
+         )
+     else:
+         args.conv_mode = conv_mode
+ 
+     conv = conv_templates[args.conv_mode].copy()
+     conv.append_message(conv.roles[0], qs)
+     conv.append_message(conv.roles[1], None)
+     prompt = conv.get_prompt()
+ 
+     image_files = image_parser(args)
+     images = load_images(image_files)
+     images_tensor = process_images(
+         images,
+         image_processor,
+         model.config
+     ).to(model.device, dtype=torch.float16)
+ 
+     input_ids = (
+         tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+         .unsqueeze(0)
+         .cuda()
+     )
+ 
+     stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+     keywords = [stop_str]
+     stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+ 
+     with torch.inference_mode():
+         output_ids = model.generate(
+             input_ids,
+             images=images_tensor,
+             do_sample=True,
+             temperature=0.2,
+             max_new_tokens=1024,
+             use_cache=True,
+             stopping_criteria=[stopping_criteria],
+         )
+ 
+     input_token_len = input_ids.shape[1]
+     n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
+     if n_diff_input_output > 0:
+         print(
+             f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
+         )
+     outputs = tokenizer.batch_decode(
+         output_ids[:, input_token_len:], skip_special_tokens=True
+     )[0]
+     outputs = outputs.strip()
+     if outputs.endswith(stop_str):
+         outputs = outputs[: -len(stop_str)]
+     outputs = outputs.strip()
+     return outputs
+ 
+ 
+ def generate_video_caption(video_path):
+     model_path = "liuhaotian/llava-v1.5-7b"
+     model_name = get_model_name_from_path(model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
+     video_id = video_path.split('/')[-1].strip().split('.')[0]
+ 
+     image_file = os.path.join("concatenated_frames", f"{video_id}.jpg")
+     prompt = "In a short paragraph, describe the process in the video."
+ 
+     args = type('Args', (), {
+         "model_path": model_path,
+         "model_base": None,
+         "model_name": get_model_name_from_path(model_path),
+         "query": prompt,
+         "conv_mode": None,
+         "image_file": image_file,
+         "sep": ",",
+         "max_new_tokens": 1024,
+         "temperature": 0.2
+     })()
+ 
+     video_caption = eval_model(args, model_name, tokenizer, model, image_processor, context_len).replace("images", "video").replace("image", "video")
+     return video_caption
+ 
+ 
+ def clean_files_and_folders():
+     shutil.rmtree("concatenated_frames")
+     shutil.rmtree("video_frames")
 
 
  def video_to_text(video_file):
      video_path = video_file.name
-     transcription = model.transcribe(video_path)
 
-     return transcription['text']
+     extract_keyframes(video_path)
+     concatenate_frames(video_path)
+     video_caption = generate_video_caption(video_path)
+     clean_files_and_folders()
+ 
+     return video_caption
 
 
  iface = gr.Interface(
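For reference, the pipeline introduced above can be exercised outside the Gradio handler by calling the new helpers in the same order video_to_text does; a minimal sketch, assuming the functions from app.py above are in scope (the sample video path is a hypothetical placeholder, not part of the commit):

    # Minimal local driver for the new captioning pipeline (sketch only).
    if __name__ == "__main__":
        video_path = "sample_video.mp4"  # hypothetical input file

        extract_keyframes(video_path)                 # detect keyframes, keep 12 evenly spaced frames
        concatenate_frames(video_path)                # tile them into a 3x4 grid image
        caption = generate_video_caption(video_path)  # caption the grid with LLaVA-v1.5-7B
        clean_files_and_folders()                     # remove video_frames/ and concatenated_frames/

        print(caption)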
requirements.txt CHANGED
@@ -1,3 +1,9 @@
1
  gradio
2
- openai-whisper
3
- ffmpeg-python
 
 
 
 
 
 
 
1
  gradio
2
+ numpy==1.26.4
3
+ opencv-python
4
+ torch==2.1.2
5
+ torchvision==0.16.2
6
+ peakutils
7
+ matplotlib
8
+ protobuf
9
+ git+git://github.com/haotian-liu/LLaVA.git
video_keyframe_detector ADDED
@@ -0,0 +1 @@
+ Subproject commit 5224a4f731ebe4277f8a04261e8268de9f9a077f