soniakhamitkar committed on
Commit
df999ba
·
verified ·
1 Parent(s): 22347c4

create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+ from decord import cpu, VideoReader, bridge
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+
8
# Identifier passed to `from_pretrained` for both the tokenizer and the model.
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# bfloat16 is chosen only when CUDA is available and the device reports a
# major compute capability of 8 or higher; every other case uses float16.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    TORCH_TYPE = torch.bfloat16
else:
    TORCH_TYPE = torch.float16

# Command-line interface. Parsing happens at import time, so `args` is a
# module-level global available to the code further down.
parser = argparse.ArgumentParser(description="CogVLM2 Video to Text")
parser.add_argument(
    '--video',
    type=str,
    required=True,
    help="Path to the video file",
)
# NOTE(review): `args.quant` is parsed here but is not read anywhere in the
# visible part of this file — confirm whether quantized loading is intended.
parser.add_argument(
    '--quant',
    type=int,
    choices=[4, 8],
    help='Enable 4-bit or 8-bit precision loading',
    default=0,
)
args = parser.parse_args()
17
+
18
+
19
def _uniform_frame_ids(total_frames, avg_fps, num_frames,
                       clip_start_sec=0, clip_end_sec=60):
    """Return `num_frames` evenly spaced frame indices inside the clip window."""
    start_frame = int(clip_start_sec * avg_fps)
    # Never index past the end of a video shorter than the clip window.
    end_frame = min(total_frames, int(clip_end_sec * avg_fps))
    return np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)


def _per_second_frame_ids(timestamps, num_frames):
    """Return, for each whole second, the index of the closest frame.

    Only the first entry of each timestamp is used (presumably the frame's
    start time — confirm against the decord documentation). At most
    `num_frames` indices are returned, one per second starting at 0.
    """
    starts = np.asarray([stamp[0] for stamp in timestamps], dtype=float)
    max_second = round(float(starts.max())) + 1
    # argmin returns the first index with the smallest distance, which is the
    # same frame a min()/index() scan would pick, in one vectorised pass.
    return [
        int(np.abs(starts - second).argmin())
        for second in range(min(max_second, num_frames))
    ]


def load_video(video_path, strategy='chat', num_frames=24):
    """Read a video file and return a sampled batch of its frames.

    Args:
        video_path: Path of the video file to open.
        strategy: Frame-sampling policy.
            'base' takes `num_frames` evenly spaced frames from the first
            60 seconds (or the whole video if it is shorter).
            'chat' takes the frame closest to each whole second, starting at
            second 0, and stops after `num_frames` frames.
        num_frames: Number of frames for 'base', upper bound for 'chat'.
            Must be at least 1. Defaults to 24.

    Returns:
        The result of decord's `get_batch` on the chosen indices, with its
        axes reordered by `permute(3, 0, 1, 2)` (presumably from
        frames/height/width/channels to channels/frames/height/width —
        confirm against the decord documentation).

    Raises:
        ValueError: If `strategy` is not 'base' or 'chat', if `num_frames`
            is smaller than 1, or if the video contains no frames.
    """
    # Validate arguments before any I/O so that mistakes fail fast with a
    # clear message instead of an obscure error from inside the decoder.
    if strategy not in ('base', 'chat'):
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected 'base' or 'chat'"
        )
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames!r}")

    # Make decord hand back torch tensors (needed for `.permute` below).
    bridge.set_bridge('torch')

    with open(video_path, 'rb') as f:
        video_stream = f.read()

    decord_vr = VideoReader(io.BytesIO(video_stream), ctx=cpu(0))
    total_frames = len(decord_vr)
    if total_frames == 0:
        raise ValueError(f"Video {video_path!r} contains no frames")

    if strategy == 'base':
        frame_id_list = _uniform_frame_ids(
            total_frames, decord_vr.get_avg_fps(), num_frames
        )
    else:  # strategy == 'chat'
        timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
        frame_id_list = _per_second_frame_ids(timestamps, num_frames)

    video_data = decord_vr.get_batch(frame_id_list)
    return video_data.permute(3, 0, 1, 2)
54
+
55
+
56
# Tokenizer for the checkpoint named by MODEL_PATH. `trust_remote_code=True`
# is passed so that code shipped with the checkpoint may be used.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

# Load the weights in TORCH_TYPE, switch the module to evaluation mode, and
# move it to DEVICE. Each step is reassigned so the order stays explicit.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
)
model = model.eval()
model = model.to(DEVICE)
66
+
67
+
68
def predict(video_path, temperature=0.1):
    """Generate a text description of the video at `video_path`.

    The video is sampled with `load_video` using the 'chat' strategy, wrapped
    into a single-turn conversation with a fixed prompt and an empty history,
    and passed to the module-level `model` for generation.

    Args:
        video_path: Path of the video file to describe.
        temperature: Value forwarded to `model.generate` as `temperature`.
            NOTE(review): it is passed together with `do_sample=False`, so it
            (like `top_p` and `top_k`) may have no effect on the output —
            confirm against the transformers generation documentation.

    Returns:
        The decoded text of the newly generated tokens, with special tokens
        skipped.
    """
    template = 'chat'
    question = "Please describe this video in detail."

    frames = load_video(video_path, strategy=template)

    encoded = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=question,
        images=[frames],
        history=[],
        template_version=template,
    )

    # Add a leading batch dimension to each tensor and move it to DEVICE.
    batch = {
        key: encoded[key].unsqueeze(0).to(DEVICE)
        for key in ('input_ids', 'token_type_ids', 'attention_mask')
    }
    # The video tensor is additionally cast to the model's dtype and wrapped
    # in a nested list, as in the original call.
    batch['images'] = [[encoded['images'][0].to(DEVICE).to(TORCH_TYPE)]]

    generation_options = dict(
        max_new_tokens=2048,
        pad_token_id=128002,
        top_k=1,
        do_sample=False,
        top_p=0.1,
        temperature=temperature,
    )

    prompt_length = batch['input_ids'].shape[1]
    with torch.no_grad():
        generated = model.generate(**batch, **generation_options)
        # Drop the prompt tokens so that only the continuation is decoded.
        continuation = generated[:, prompt_length:]
        return tokenizer.decode(continuation[0], skip_special_tokens=True)
101
+
102
+
103
if __name__ == '__main__':
    # Describe the video given via --video and print the text under a heading.
    response_text = predict(args.video)
    print("\nGenerated Text Description:\n")
    print(response_text)