TobyYang7 commited on
Commit
b8bd6a9
1 Parent(s): afff347

Update llava_llama3/serve/cli.py

Browse files
Files changed (1) hide show
  1. llava_llama3/serve/cli.py +44 -35
llava_llama3/serve/cli.py CHANGED
@@ -26,58 +26,67 @@ def load_image(image_file):
26
  return image
27
 
28
 
29
- def chat_llava(args, image_file, text, tokenizer, model, streamer, image_processor, context_len):
30
  # Model
31
  disable_torch_init()
32
 
33
  conv = conv_templates[args.conv_mode].copy()
34
  roles = conv.roles
35
 
36
- print(f"\033[91m{image_file}, {type(image_file)}\033[0m")
37
- image = load_image(image_file)
38
- print(f"\033[91m{image}, {type(image)}\033[0m")
39
-
40
- image_size = image.size
41
- image_tensor = process_images([image], image_processor, model.config)
42
- if type(image_tensor) is list:
43
- image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
44
- else:
45
- image_tensor = image_tensor.to(model.device, dtype=torch.float16)
46
-
47
  inp = text
48
 
49
- # print(f"{roles[1]}: ", end="")
50
-
51
- if image is not None:
52
- # first message
 
 
 
 
 
 
 
53
  if model.config.mm_use_im_start_end:
54
  inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
55
  else:
56
  inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
57
- image = None
58
-
59
- conv.append_message(conv.roles[0], inp)
60
- conv.append_message(conv.roles[1], None)
61
- prompt = conv.get_prompt()
62
 
63
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
64
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
65
- keywords = [stop_str]
66
- # streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 
 
67
 
68
- with torch.inference_mode():
69
- output_ids = model.generate(
70
- input_ids,
71
- images=image_tensor,
72
- image_sizes=[image_size],
73
- do_sample=True if args.temperature > 0 else False,
74
- temperature=args.temperature,
75
- max_new_tokens=args.max_new_tokens,
76
- streamer=streamer,
77
- use_cache=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  outputs = tokenizer.decode(output_ids[0]).strip()
80
  conv.messages[-1][-1] = outputs
81
 
82
  # Return the model's output as a string
 
83
  return outputs.replace('<|end_of_text|>', '\n').lstrip()
 
26
  return image
27
 
28
 
29
def chat_llava(args, image_file, text, tokenizer, model, image_processor, context_len, streamer=None):
    """Run one single-turn LLaVA chat exchange and return the model's reply.

    Args:
        args: Namespace with at least `conv_mode`, `temperature`, `max_new_tokens`.
        image_file: Path/URL of an image, or None for a text-only prompt.
        text: The user's message.
        tokenizer: The model's tokenizer.
        model: The loaded LLaVA model (must expose `.config`, `.device`, `.generate`).
        image_processor: Processor used by `process_images` for the vision tower.
        context_len: Context length (unused here; kept for interface compatibility).
        streamer: Optional TextStreamer; only used on the multimodal path.

    Returns:
        The decoded model output as a string, with '<|end_of_text|>' markers
        converted to newlines and leading whitespace stripped.
    """
    disable_torch_init()

    conv = conv_templates[args.conv_mode].copy()
    inp = text

    # Shared sampling configuration for both the multimodal and text-only paths.
    gen_kwargs = dict(
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        use_cache=True,
    )

    if image_file is not None:
        # Multimodal path: load and preprocess the image, then splice the
        # image placeholder token(s) into the user prompt.
        image = load_image(image_file)
        image_size = image.size
        image_tensor = process_images([image], image_processor, model.config)
        if isinstance(image_tensor, list):
            image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)

        if model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # tokenizer_image_token replaces the image placeholder with IMAGE_TOKEN_INDEX.
        input_ids = tokenizer_image_token(
            prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(model.device)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image_size],
                streamer=streamer,
                **gen_kwargs)
    else:
        # Text-only path: plain tokenization, no image tensors or streamer.
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)

        with torch.inference_mode():
            output_ids = model.generate(input_ids, **gen_kwargs)

    outputs = tokenizer.decode(output_ids[0]).strip()
    conv.messages[-1][-1] = outputs

    # Normalize the Llama 3 end-of-text marker into a newline before returning.
    return outputs.replace('<|end_of_text|>', '\n').lstrip()