VishalD1234 commited on
Commit
87014d0
·
verified ·
1 Parent(s): ab92603

Create process.py

Browse files
Files changed (1) hide show
  1. process.py +111 -0
process.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import spaces
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ from decord import cpu, VideoReader, bridge
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from transformers import BitsAndBytesConfig
9
+
10
+ MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
11
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
13
+ 0] >= 8 else torch.float16
14
+
15
+ parser = argparse.ArgumentParser(description="Apollo tyre delay reasoning")
16
+ parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=4)
17
+ args = parser.parse_args([])
18
+
19
+ def load_video(video_data, strategy='chat'):
20
+ bridge.set_bridge('torch')
21
+ mp4_stream = video_data
22
+ num_frames = 24
23
+ decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))
24
+ frame_id_list = None
25
+ total_frames = len(decord_vr)
26
+
27
+ if strategy == 'base':
28
+ clip_end_sec = 60
29
+ clip_start_sec = 0
30
+ start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
31
+ end_frame = min(total_frames,
32
+ int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames
33
+ frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
34
+ elif strategy == 'chat':
35
+ timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
36
+ timestamps = [i[0] for i in timestamps]
37
+ max_second = round(max(timestamps)) + 1
38
+ frame_id_list = []
39
+ for second in range(max_second):
40
+ closest_num = min(timestamps, key=lambda x: abs(x - second))
41
+ index = timestamps.index(closest_num)
42
+ frame_id_list.append(index)
43
+ if len(frame_id_list) >= num_frames:
44
+ break
45
+
46
+ video_data = decord_vr.get_batch(frame_id_list)
47
+ video_data = video_data.permute(3, 0, 1, 2)
48
+ return video_data
49
+
50
+ # Configure quantization
51
+ quantization_config = BitsAndBytesConfig(
52
+ load_in_4bit=True,
53
+ bnb_4bit_compute_dtype=TORCH_TYPE,
54
+ bnb_4bit_use_double_quant=True,
55
+ bnb_4bit_quant_type="nf4"
56
+ )
57
+
58
+ tokenizer = AutoTokenizer.from_pretrained(
59
+ MODEL_PATH,
60
+ trust_remote_code=True,
61
+ )
62
+
63
+ model = AutoModelForCausalLM.from_pretrained(
64
+ MODEL_PATH,
65
+ torch_dtype=TORCH_TYPE,
66
+ trust_remote_code=True,
67
+ quantization_config=quantization_config,
68
+ device_map="auto"
69
+ ).eval()
70
+
71
+ @spaces.GPU
72
+ def predict(prompt, video_data, temperature):
73
+ strategy = 'chat'
74
+ video = load_video(video_data, strategy=strategy)
75
+ history = []
76
+ query = prompt
77
+ inputs = model.build_conversation_input_ids(
78
+ tokenizer=tokenizer,
79
+ query=query,
80
+ images=[video],
81
+ history=history,
82
+ template_version=strategy
83
+ )
84
+
85
+ inputs = {
86
+ 'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
87
+ 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
88
+ 'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
89
+ 'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
90
+ }
91
+
92
+ gen_kwargs = {
93
+ "max_new_tokens": 2048,
94
+ "pad_token_id": 128002,
95
+ "top_k": 1,
96
+ "do_sample": False,
97
+ "top_p": 0.1,
98
+ "temperature": temperature,
99
+ }
100
+
101
+ with torch.no_grad():
102
+ outputs = model.generate(**inputs, **gen_kwargs)
103
+ outputs = outputs[:, inputs['input_ids'].shape[1]:]
104
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
105
+ return response
106
+
107
+ def inference(video, prompt):
108
+ temperature = 0.8
109
+ video_data = open(video, 'rb').read()
110
+ response = predict(prompt, video_data, temperature)
111
+ return response