Update llava_llama3/serve/cli.py

llava_llama3/serve/cli.py · CHANGED · +44 -35
@@ -26,58 +26,67 @@ def load_image(image_file):
     return image
 
 
-def chat_llava(args, image_file, text, tokenizer, model, … [rest of the old line not captured in this view]
+def chat_llava(args, image_file, text, tokenizer, model, image_processor, context_len, streamer=None):
     # Model
     disable_torch_init()
 
     conv = conv_templates[args.conv_mode].copy()
     roles = conv.roles
 
-    print(f"\033[91m{image_file}, {type(image_file)}\033[0m")
-    image = load_image(image_file)
-    print(f"\033[91m{image}, {type(image)}\033[0m")
-
-    image_size = image.size
-    image_tensor = process_images([image], image_processor, model.config)
-    if type(image_tensor) is list:
-        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
-    else:
-        image_tensor = image_tensor.to(model.device, dtype=torch.float16)
-
     inp = text
 
-    [old lines 49-52 removed; their content was not captured in this view]
+    if image_file is not None:
+        print(image_file, type(image_file))
+        image = load_image(image_file)
+        print(image, type(image))
+        image_size = image.size
+        image_tensor = process_images([image], image_processor, model.config)
+        if type(image_tensor) is list:
+            image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
+        else:
+            image_tensor = image_tensor.to(model.device, dtype=torch.float16)
+
         if model.config.mm_use_im_start_end:
             inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
         else:
             inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
-        image = None
-
-    conv.append_message(conv.roles[0], inp)
-    conv.append_message(conv.roles[1], None)
-    prompt = conv.get_prompt()
 
-    [old lines 63-66 removed; their content was not captured in this view]
+        conv.append_message(conv.roles[0], inp)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+        keywords = [stop_str]
 
-    [old lines 68-77 removed; their content was not captured in this view]
+        with torch.inference_mode():
+            output_ids = model.generate(
+                input_ids,
+                images=image_tensor,
+                image_sizes=[image_size],
+                do_sample=True if args.temperature > 0 else False,
+                temperature=args.temperature,
+                max_new_tokens=args.max_new_tokens,
+                streamer=streamer,
+                use_cache=True)
+    else:
+        conv.append_message(conv.roles[0], inp)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
+
+        with torch.inference_mode():
+            output_ids = model.generate(
+                input_ids,
+                do_sample=True if args.temperature > 0 else False,
+                temperature=args.temperature,
+                max_new_tokens=args.max_new_tokens,
+                use_cache=True)
 
     outputs = tokenizer.decode(output_ids[0]).strip()
     conv.messages[-1][-1] = outputs
 
     # Return the model's output as a string
+    # return outputs
    return outputs.replace('<|end_of_text|>', '\n').lstrip()
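In effect, the commit makes the image input optional and threads an optional streamer through to generation: with an image, the prompt is tokenized via tokenizer_image_token and generate() receives images and image_sizes; with image_file=None, the new else branch tokenizes the plain prompt and generates text-only. A minimal driver sketch, assuming the surrounding setup: the loader that produces tokenizer, model, image_processor, and context_len, the conv_mode value, and the sampling settings below are illustrative assumptions, not part of this commit.

import argparse

def demo_chat(tokenizer, model, image_processor, context_len):
    # Hypothetical driver: tokenizer/model/image_processor/context_len are
    # assumed to come from the repo's usual checkpoint loader (not in this diff).
    # chat_llava only reads conv_mode, temperature, and max_new_tokens from args.
    args = argparse.Namespace(conv_mode="llama_3", temperature=0.2, max_new_tokens=512)

    # Image turn: takes the `if image_file is not None` branch, so the prompt
    # carries an image token and generate() receives images/image_sizes.
    print(chat_llava(args, "view.jpg", "What is shown in this image?",
                     tokenizer, model, image_processor, context_len))

    # Text-only turn: image_file=None takes the new else branch, which uses the
    # plain tokenizer and calls generate() without image inputs.
    print(chat_llava(args, None, "Write a one-line summary of LLaVA.",
                     tokenizer, model, image_processor, context_len))

Note that only the image branch forwards streamer= to generate(), so passing a transformers TextStreamer would stream tokens incrementally for image turns but not for text-only turns.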