lorocksUMD committed
Commit
a3de5d2
verified · 1 Parent(s): 43605d2

Create multi_script.py

Files changed (1): multi_script.py +168 -0
multi_script.py ADDED
@@ -0,0 +1,168 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+
+ from transformers import AutoTokenizer
+ from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM
+ from llava.model.builder import load_pretrained_model
+ from llava.mm_utils import (
+     process_images,
+     tokenizer_image_token,
+     get_model_name_from_path,
+ )
+ from llava.constants import (
+     IMAGE_TOKEN_INDEX,
+     DEFAULT_IMAGE_TOKEN,
+     DEFAULT_IM_START_TOKEN,
+     DEFAULT_IM_END_TOKEN,
+     IMAGE_PLACEHOLDER,
+ )
+ from llava.conversation import conv_templates, SeparatorStyle
+
+ import argparse
+ import torch
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ import re
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="liuhaotian/llava-v1.6-mistral-7b")
+ parser.add_argument("--image-file", type=str, required=True)
+ parser.add_argument("--inference-type", type=str, default="auto")
+ parser.add_argument("--prompt", type=str, default="Explain this image")
+ cmd_args = parser.parse_args()
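+ # Example invocation (illustrative values, assuming a local image file):
+ #   python multi_script.py --image-file ./sample.jpg --inference-type auto --prompt "Explain this image"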
+
+ # Uncomment the .cuda() call on input_ids (further down) to run on GPU
+
+ # device = "cpu"
+ device = cmd_args.inference_type
+
+ prompt = cmd_args.prompt
+ image_file = cmd_args.image_file
+
+ model_path = cmd_args.model_path
+
+
+
+ # Functions for inference
+ def image_parser(args):
+     out = args.image_file.split(args.sep)
+     return out
+
+
+ def load_image(image_file):
+     if image_file.startswith("http") or image_file.startswith("https"):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert("RGB")
+     else:
+         image = Image.open(image_file).convert("RGB")
+     return image
+
+
+ def load_images(image_files):
+     out = []
+     for image_file in image_files:
+         image = load_image(image_file)
+         out.append(image)
+     return out
+
+
+ model_name = get_model_name_from_path(model_path)
+
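+ # Collect the generation settings in a lightweight namespace object, mirroring the
+ # argument object that LLaVA's eval utilities expect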
+ args = type('Args', (), {
+     "model_path": model_path,
+     "model_base": None,
+     "model_name": model_name,
+     "query": prompt,
+     "conv_mode": None,
+     "image_file": image_file,
+     "sep": ",",
+     "temperature": 0,
+     "top_p": None,
+     "num_beams": 1,
+     "max_new_tokens": 512
+ })()
+
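+ # Load the tokenizer, LLaVA weights, and vision tower; device placement follows --inference-type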
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+     model_path, None, model_name, device_map=device
+ )
+
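+ # Insert the image token into the prompt, honouring the model's mm_use_im_start_end setting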
+ qs = args.query
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+ if IMAGE_PLACEHOLDER in qs:
+     if model.config.mm_use_im_start_end:
+         qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
+     else:
+         qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
+ else:
+     if model.config.mm_use_im_start_end:
+         qs = image_token_se + "\n" + qs
+     else:
+         qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
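+ # Pick the conversation template from the model name (mistral -> mistral_instruct here)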
+ if "llama-2" in model_name.lower():
+     conv_mode = "llava_llama_2"
+ elif "mistral" in model_name.lower():
+     conv_mode = "mistral_instruct"
+ elif "v1.6-34b" in model_name.lower():
+     conv_mode = "chatml_direct"
+ elif "v1" in model_name.lower():
+     conv_mode = "llava_v1"
+ elif "mpt" in model_name.lower():
+     conv_mode = "mpt"
+ else:
+     conv_mode = "llava_v0"
+
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
+     print(
+         "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
+             conv_mode, args.conv_mode, args.conv_mode
+         )
+     )
+ else:
+     args.conv_mode = conv_mode
+
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
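+ # Load the image(s), record their sizes, and preprocess them into a float16 tensor on the model's device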
+ image_files = image_parser(args)
+ images = load_images(image_files)
+ image_sizes = [x.size for x in images]
+ images_tensor = process_images(
+     images,
+     image_processor,
+     model.config
+ ).to(model.device, dtype=torch.float16)
+
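+ # Tokenize the prompt, mapping the image token to IMAGE_TOKEN_INDEX, and add a batch dimension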
+ input_ids = (
+     tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+     .unsqueeze(0)
+     # .cuda()
+ )
+
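+ # Generate the answer; with temperature 0 this is greedy decoding (do_sample=False)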
+ with torch.inference_mode():
+     output_ids = model.generate(
+         input_ids,
+         images=images_tensor,
+         image_sizes=image_sizes,
+         do_sample=True if args.temperature > 0 else False,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         num_beams=args.num_beams,
+         max_new_tokens=args.max_new_tokens,
+         use_cache=True,
+     )
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
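+ # Simple per-dataset checks on the decoded answer, keyed on the image path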
+ if "dataset1" in image_file:
+     print("Num of words: ", len(outputs.split()))
+ elif "dataset2" in image_file:
+     print()
+ else:
+     print("Is single word?", len(outputs.split()) == 1)
+
+ print(outputs)
+ # End Llava inference