Martín Santillán Cooper committed
Commit fb6a6b8 · 1 Parent(s): 2b6005c
Files changed (1)
  1. model.py +6 -5
model.py CHANGED
@@ -146,12 +146,7 @@ def parse_output_watsonx(generated_tokens_list):
 @spaces.GPU
 def generate_text(messages, criteria_name):
     logger.debug(f"Messages used to create the prompt are: \n{messages}")
-
     start = time()
-
-    chat = get_prompt(messages, criteria_name)
-    logger.debug(f"Prompt is \n{chat}")
-
     if inference_engine == "MOCK":
         logger.debug("Returning mocked model result.")
         sleep(1)
@@ -159,12 +154,15 @@ def generate_text(messages, criteria_name):
 
     elif inference_engine == "WATSONX":
         chat = get_prompt(messages, criteria_name)
+        logger.debug(f"Prompt is \n{chat}")
         generated_tokens = generate_tokens(chat)
         label, prob_of_risk = parse_output_watsonx(generated_tokens)
 
     elif inference_engine == "VLLM":
         input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
+        logger.debug(f"input_ids are: {input_ids}")
         input_len = input_ids.shape[1]
+        logger.debug(f"input_len is: {input_len}")
 
         with torch.no_grad():
             # output = model.generate(chat, sampling_params, use_tqdm=False)
@@ -174,8 +172,11 @@ def generate_text(messages, criteria_name):
                 max_new_tokens=nlogprobs,
                 return_dict_in_generate=True,
                 output_scores=True,)
+            logger.debug(f"model output is: {output}")
 
         label, prob_of_risk = parse_output(output, input_len)
+        logger.debug(f"label is: {label}")
+        logger.debug(f"prob_of_risk is: {prob_of_risk}")
     else:
         raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
 
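The new debug lines in the VLLM branch rely on the fields that return_dict_in_generate=True and output_scores=True expose on the generate() result: output.sequences (prompt plus generated token ids) and output.scores (one logits tensor per generated step). Below is a rough, hypothetical sketch of how a parse_output-style helper could turn that object into the label and prob_of_risk being logged, assuming a safe/risky "No"/"Yes" token convention; the names safe_token, risky_token and parse_output_sketch are illustrative only, and the real parse_output in model.py is not shown in this diff.

import torch

# Illustrative token names; the actual convention used by model.py is an assumption here.
safe_token, risky_token = "No", "Yes"

def parse_output_sketch(output, input_len, tokenizer):
    # Text generated after the prompt; input_len marks where the prompt ends.
    generated_text = tokenizer.decode(output.sequences[0, input_len:], skip_special_tokens=True).strip()

    # output.scores holds one (batch, vocab_size) logits tensor per generated step;
    # softmax the first step, where the safe/risky decision token is produced.
    first_step_probs = torch.softmax(output.scores[0][0], dim=-1)

    risky_id = tokenizer.convert_tokens_to_ids(risky_token)
    safe_id = tokenizer.convert_tokens_to_ids(safe_token)

    # Renormalise over just the two tokens of interest.
    prob_of_risk = first_step_probs[risky_id] / (first_step_probs[risky_id] + first_step_probs[safe_id])

    label = risky_token if generated_text.startswith(risky_token) else safe_token
    return label, prob_of_risk.item()

A sketch like this also explains why the added logger.debug(f"model output is: {output}") line can be verbose: output carries the full scores tuple for every generated step, not just the decoded text.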